Skip to content

Commit 6d33a8a

Browse files
committed
rollback
1 parent 909af8d commit 6d33a8a

File tree

4 files changed

+62
-92
lines changed

4 files changed

+62
-92
lines changed

examples/example.py

Lines changed: 0 additions & 64 deletions
This file was deleted.

scrapegraphai/graphs/abstract_graph.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,21 @@
88
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
99
from langchain_google_genai import GoogleGenerativeAIEmbeddings
1010
from ..helpers import models_tokens
11-
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic
11+
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic, DeepSeek
1212
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
1313

14+
from ..helpers import models_tokens
15+
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic, DeepSeek
16+
17+
1418
class AbstractGraph(ABC):
1519
"""
1620
Scaffolding class for creating a graph representation and executing it.
1721
1822
prompt (str): The prompt for the graph.
1923
source (str): The source of the graph.
2024
config (dict): Configuration parameters for the graph.
25+
schema (str): The schema for the graph output.
2126
llm_model: An instance of a language model client, configured for generating answers.
2227
embedder_model: An instance of an embedding model client,
2328
configured for generating embeddings.
@@ -28,6 +33,7 @@ class AbstractGraph(ABC):
2833
prompt (str): The prompt for the graph.
2934
config (dict): Configuration parameters for the graph.
3035
source (str, optional): The source of the graph.
36+
schema (str, optional): The schema for the graph output.
3137
3238
Example:
3339
>>> class MyGraph(AbstractGraph):
@@ -39,34 +45,42 @@ class AbstractGraph(ABC):
3945
>>> result = my_graph.run()
4046
"""
4147

42-
def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
48+
def __init__(self, prompt: str, config: dict, source: Optional[str] = None, schema: Optional[str] = None):
4349

4450
self.prompt = prompt
4551
self.source = source
4652
self.config = config
53+
self.schema = schema
4754
self.llm_model = self._create_llm(config["llm"], chat=True)
4855
self.embedder_model = self._create_default_embedder(llm_config=config["llm"]
4956
) if "embeddings" not in config else self._create_embedder(
5057
config["embeddings"])
58+
self.verbose = False if config is None else config.get(
59+
"verbose", False)
60+
self.headless = True if config is None else config.get(
61+
"headless", True)
62+
self.loader_kwargs = config.get("loader_kwargs", {})
5163

5264
# Create the graph
5365
self.graph = self._create_graph()
5466
self.final_state = None
5567
self.execution_info = None
5668

5769
# Set common configuration parameters
58-
5970
self.verbose = False if config is None else config.get(
6071
"verbose", False)
6172
self.headless = True if config is None else config.get(
6273
"headless", True)
6374
self.loader_kwargs = config.get("loader_kwargs", {})
6475

65-
common_params = {"headless": self.headless,
66-
67-
"loader_kwargs": self.loader_kwargs,
68-
"llm_model": self.llm_model,
69-
"embedder_model": self.embedder_model}
76+
common_params = {
77+
"headless": self.headless,
78+
"verbose": self.verbose,
79+
"loader_kwargs": self.loader_kwargs,
80+
"llm_model": self.llm_model,
81+
"embedder_model": self.embedder_model
82+
}
83+
7084
self.set_common_params(common_params, overwrite=False)
7185

7286
def set_common_params(self, params: dict, overwrite=False):
@@ -79,7 +93,7 @@ def set_common_params(self, params: dict, overwrite=False):
7993

8094
for node in self.graph.nodes:
8195
node.update_config(params, overwrite)
82-
96+
8397
def _set_model_token(self, llm):
8498

8599
if 'Azure' in str(type(llm)):
@@ -157,7 +171,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
157171
raise KeyError("Model not supported") from exc
158172
return Anthropic(llm_params)
159173
elif "ollama" in llm_params["model"]:
160-
llm_params["model"] = llm_params["model"].split("/")[-1]
174+
llm_params["model"] = llm_params["model"].split("ollama/")[-1]
161175

162176
# allow user to set model_tokens in config
163177
try:
@@ -231,6 +245,8 @@ def _create_default_embedder(self, llm_config=None) -> object:
231245
model="models/embedding-001")
232246
if isinstance(self.llm_model, OpenAI):
233247
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
248+
elif isinstance(self.llm_model, DeepSeek):
249+
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
234250
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
235251
return self.llm_model
236252
elif isinstance(self.llm_model, AzureOpenAI):
@@ -271,7 +287,7 @@ def _create_embedder(self, embedder_config: dict) -> object:
271287
elif "azure" in embedder_config["model"]:
272288
return AzureOpenAIEmbeddings()
273289
elif "ollama" in embedder_config["model"]:
274-
embedder_config["model"] = embedder_config["model"].split("/")[-1]
290+
embedder_config["model"] = embedder_config["model"].split("ollama/")[-1]
275291
try:
276292
models_tokens["ollama"][embedder_config["model"]]
277293
except KeyError as exc:
@@ -297,6 +313,10 @@ def _create_embedder(self, embedder_config: dict) -> object:
297313
except KeyError as exc:
298314
raise KeyError("Model not supported") from exc
299315
return BedrockEmbeddings(client=client, model_id=embedder_config["model"])
316+
else:
317+
raise ValueError(
318+
"Model provided by the configuration not supported")
319+
300320
def get_state(self, key=None) -> dict:
301321
"""""
302322
Get the final state of the graph.
@@ -334,4 +354,4 @@ def run(self) -> str:
334354
"""
335355
Abstract method to execute the graph and return the result.
336356
"""
337-
pass
357+
pass

scrapegraphai/graphs/pdf_scraper_graph.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
FetchNode,
1212
ParseNode,
1313
RAGNode,
14-
GenerateAnswerPDFNode
14+
GenerateAnswerNode
1515
)
1616

1717

@@ -48,7 +48,7 @@ class PDFScraperGraph(AbstractGraph):
4848
"""
4949

5050
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
51-
super().__init__(prompt, config, source)
51+
super().__init__(prompt, config, source, schema)
5252

5353
self.input_key = "pdf" if source.endswith("pdf") else "pdf_dir"
5454

@@ -64,21 +64,41 @@ def _create_graph(self) -> BaseGraph:
6464
input='pdf | pdf_dir',
6565
output=["doc", "link_urls", "img_urls"],
6666
)
67-
generate_answer_node_pdf = GenerateAnswerPDFNode(
67+
parse_node = ParseNode(
68+
input="doc",
69+
output=["parsed_doc"],
70+
node_config={
71+
"chunk_size": self.model_token,
72+
}
73+
)
74+
rag_node = RAGNode(
75+
input="user_prompt & (parsed_doc | doc)",
76+
output=["relevant_chunks"],
77+
node_config={
78+
"llm_model": self.llm_model,
79+
"embedder_model": self.embedder_model,
80+
}
81+
)
82+
generate_answer_node = GenerateAnswerNode(
6883
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
6984
output=["answer"],
7085
node_config={
7186
"llm_model": self.llm_model,
87+
"schema": self.schema,
7288
}
7389
)
7490

7591
return BaseGraph(
7692
nodes=[
7793
fetch_node,
78-
generate_answer_node_pdf,
94+
parse_node,
95+
rag_node,
96+
generate_answer_node,
7997
],
8098
edges=[
81-
(fetch_node, generate_answer_node_pdf)
99+
(fetch_node, parse_node),
100+
(parse_node, rag_node),
101+
(rag_node, generate_answer_node)
82102
],
83103
entry_point=fetch_node
84104
)
@@ -94,4 +114,4 @@ def run(self) -> str:
94114
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
95115
self.final_state, self.execution_info = self.graph.execute(inputs)
96116

97-
return self.final_state.get("answer", "No answer found.")
117+
return self.final_state.get("answer", "No answer found.")

scrapegraphai/graphs/smart_scraper_graph.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,14 @@
22
SmartScraperGraph Module
33
"""
44

5-
from typing import Optional
6-
75
from .base_graph import BaseGraph
8-
from .abstract_graph import AbstractGraph
9-
106
from ..nodes import (
117
FetchNode,
128
ParseNode,
139
RAGNode,
1410
GenerateAnswerNode
1511
)
12+
from .abstract_graph import AbstractGraph
1613

1714

1815
class SmartScraperGraph(AbstractGraph):
@@ -25,7 +22,6 @@ class SmartScraperGraph(AbstractGraph):
2522
prompt (str): The prompt for the graph.
2623
source (str): The source of the graph.
2724
config (dict): Configuration parameters for the graph.
28-
schema (str): The schema for the graph output.
2925
llm_model: An instance of a language model client, configured for generating answers.
3026
embedder_model: An instance of an embedding model client,
3127
configured for generating embeddings.
@@ -36,7 +32,6 @@ class SmartScraperGraph(AbstractGraph):
3632
prompt (str): The prompt for the graph.
3733
source (str): The source of the graph.
3834
config (dict): Configuration parameters for the graph.
39-
schema (str): The schema for the graph output.
4035
4136
Example:
4237
>>> smart_scraper = SmartScraperGraph(
@@ -48,8 +43,8 @@ class SmartScraperGraph(AbstractGraph):
4843
)
4944
"""
5045

51-
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
52-
super().__init__(prompt, config, source, schema)
46+
def __init__(self, prompt: str, source: str, config: dict):
47+
super().__init__(prompt, config, source)
5348

5449
self.input_key = "url" if source.startswith("http") else "local_dir"
5550

@@ -86,8 +81,7 @@ def _create_graph(self) -> BaseGraph:
8681
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
8782
output=["answer"],
8883
node_config={
89-
"llm_model": self.llm_model,
90-
"schema": self.schema,
84+
"llm_model": self.llm_model
9185
}
9286
)
9387

0 commit comments

Comments
 (0)