Skip to content

Commit 376f758

Browse files
committed
feat(pydantic): added pydantic output schema
1 parent 1d217e4 commit 376f758

23 files changed

+165
-125
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""
2+
Example of Search Graph
3+
"""
4+
5+
import os
6+
from dotenv import load_dotenv
7+
load_dotenv()
8+
9+
from scrapegraphai.graphs import SearchGraph
10+
from scrapegraphai.utils import convert_to_csv, convert_to_json, prettify_exec_info
11+
12+
from pydantic import BaseModel, Field
13+
from typing import List
14+
15+
# ************************************************
16+
# Define the output schema for the graph
17+
# ************************************************
18+
19+
class Dish(BaseModel):
    # Output schema for a single dish. Plain comments are used instead of a
    # class docstring so the schema generated from this model is unchanged.
    name: str = Field(description="The name of the dish")
    description: str = Field(description="The description of the dish")
22+
23+
class Dishes(BaseModel):
    # Top-level output schema: a list of Dish entries. This class (not an
    # instance of it) is what the script passes to SearchGraph as `schema`.
    dishes: List[Dish]
25+
26+
# ************************************************
27+
# Define the configuration for the graph
28+
# ************************************************
29+
30+
# The OpenAI key is read from the OPENAI_APIKEY environment variable;
# os.getenv yields None when the variable is not set.
openai_key = os.getenv("OPENAI_APIKEY")

# Configuration dictionary handed to the graph constructor below.
graph_config = {
    "llm": {
        "api_key": openai_key,
        "model": "gpt-3.5-turbo",
    },
    "max_results": 2,
    "verbose": True,
}
40+
41+
# ************************************************
42+
# Create the SearchGraph instance and run it
43+
# ************************************************
44+
45+
# Build the search graph from the prompt, the configuration dictionary and
# the Dishes pydantic class as the requested output schema.
search_graph = SearchGraph(
    prompt="List me Chioggia's famous dishes",
    config=graph_config,
    schema=Dishes,
)

# Execute the graph and show whatever it returns.
result = search_graph.run()
print(result)
53+
54+
# ************************************************
55+
# Get graph execution info
56+
# ************************************************
57+
58+
# Fetch the execution info recorded by the graph and print it in the
# form produced by prettify_exec_info.
graph_exec_info = search_graph.get_execution_info()
print(prettify_exec_info(graph_exec_info))

# Save to json and csv
convert_to_csv(result, "result")
convert_to_json(result, "result")

examples/openai/smart_scraper_schema_openai.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44

55
import os, json
66
from dotenv import load_dotenv
7+
from pydantic import BaseModel, Field
8+
from typing import List
9+
710
from scrapegraphai.graphs import SmartScraperGraph
811

912
load_dotenv()
@@ -12,22 +15,12 @@
1215
# Define the output schema for the graph
1316
# ************************************************
1417

15-
schema= """
16-
{
17-
"Projects": [
18-
"Project #":
19-
{
20-
"title": "...",
21-
"description": "...",
22-
},
23-
"Project #":
24-
{
25-
"title": "...",
26-
"description": "...",
27-
}
28-
]
29-
}
30-
"""
18+
class Project(BaseModel):
    # Output schema for a single project. Plain comments are used instead of
    # a class docstring so the schema generated from this model is unchanged.
    title: str = Field(description="The title of the project")
    description: str = Field(description="The description of the project")
21+
22+
class Projects(BaseModel):
    # Top-level output schema: a list of Project entries. This class (not an
    # instance of it) is what the script passes to SmartScraperGraph as
    # `schema`, replacing the former string template.
    projects: List[Project]
3124

3225
# ************************************************
3326
# Define the configuration for the graph
@@ -51,9 +44,9 @@
5144
smart_scraper_graph = SmartScraperGraph(
5245
prompt="List me all the projects with their description",
5346
source="https://perinim.github.io/projects/",
54-
schema=schema,
47+
schema=Projects,
5548
config=graph_config
5649
)
5750

5851
result = smart_scraper_graph.run()
59-
print(json.dumps(result, indent=4))
52+
print(result)

scrapegraphai/graphs/abstract_graph.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
"""
44

55
from abc import ABC, abstractmethod
6-
from typing import Optional
6+
from typing import Optional, Union
77
import uuid
8+
from pydantic import BaseModel
89

910
from langchain_aws import BedrockEmbeddings
1011
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
@@ -62,7 +63,7 @@ class AbstractGraph(ABC):
6263
"""
6364

6465
def __init__(self, prompt: str, config: dict,
65-
source: Optional[str] = None, schema: Optional[str] = None):
66+
source: Optional[str] = None, schema: Optional[BaseModel] = None):
6667

6768
self.prompt = prompt
6869
self.source = source

scrapegraphai/graphs/csv_scraper_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from typing import Optional
6+
from pydantic import BaseModel
67

78
from .base_graph import BaseGraph
89
from .abstract_graph import AbstractGraph
@@ -20,7 +21,7 @@ class CSVScraperGraph(AbstractGraph):
2021
information from web pages using a natural language model to interpret and answer prompts.
2122
"""
2223

23-
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
24+
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
2425
"""
2526
Initializes the CSVScraperGraph with a prompt, source, and configuration.
2627
"""

scrapegraphai/graphs/deep_scraper_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from typing import Optional
6+
from pydantic import BaseModel
67

78
from .base_graph import BaseGraph
89
from .abstract_graph import AbstractGraph
@@ -56,7 +57,7 @@ class DeepScraperGraph(AbstractGraph):
5657
)
5758
"""
5859

59-
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
60+
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
6061

6162
super().__init__(prompt, config, source, schema)
6263

scrapegraphai/graphs/json_scraper_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from typing import Optional
6+
from pydantic import BaseModel
67

78
from .base_graph import BaseGraph
89
from .abstract_graph import AbstractGraph
@@ -44,7 +45,7 @@ class JSONScraperGraph(AbstractGraph):
4445
>>> result = json_scraper.run()
4546
"""
4647

47-
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
48+
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
4849
super().__init__(prompt, config, source, schema)
4950

5051
self.input_key = "json" if source.endswith("json") else "json_dir"

scrapegraphai/graphs/omni_scraper_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from typing import Optional
6+
from pydantic import BaseModel
67

78
from .base_graph import BaseGraph
89
from .abstract_graph import AbstractGraph
@@ -52,7 +53,7 @@ class OmniScraperGraph(AbstractGraph):
5253
)
5354
"""
5455

55-
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
56+
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
5657

5758
self.max_images = 5 if config is None else config.get("max_images", 5)
5859

scrapegraphai/graphs/omni_search_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from copy import copy, deepcopy
66
from typing import Optional
7+
from pydantic import BaseModel
78

89
from .base_graph import BaseGraph
910
from .abstract_graph import AbstractGraph
@@ -43,7 +44,7 @@ class OmniSearchGraph(AbstractGraph):
4344
>>> result = search_graph.run()
4445
"""
4546

46-
def __init__(self, prompt: str, config: dict, schema: Optional[str] = None):
47+
def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None):
4748

4849
self.max_results = config.get("max_results", 3)
4950

scrapegraphai/graphs/pdf_scraper_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from typing import Optional
6+
from pydantic import BaseModel
67

78
from .base_graph import BaseGraph
89
from .abstract_graph import AbstractGraph
@@ -46,7 +47,7 @@ class PDFScraperGraph(AbstractGraph):
4647
>>> result = pdf_scraper.run()
4748
"""
4849

49-
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
50+
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
5051
super().__init__(prompt, config, source, schema)
5152

5253
self.input_key = "pdf" if source.endswith("pdf") else "pdf_dir"

scrapegraphai/graphs/script_creator_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from typing import Optional
6+
from pydantic import BaseModel
67

78
from .base_graph import BaseGraph
89
from .abstract_graph import AbstractGraph
@@ -46,7 +47,7 @@ class ScriptCreatorGraph(AbstractGraph):
4647
>>> result = script_creator.run()
4748
"""
4849

49-
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
50+
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
5051

5152
self.library = config['library']
5253

0 commit comments

Comments
 (0)