Skip to content

Commit 909af8d

Browse files
committed
refactor generate-answer node
1 parent 1774b18 commit 909af8d

File tree

3 files changed

+22
-59
lines changed

3 files changed

+22
-59
lines changed

scrapegraphai/graphs/abstract_graph.py

Lines changed: 16 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,16 @@
88
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
99
from langchain_google_genai import GoogleGenerativeAIEmbeddings
1010
from ..helpers import models_tokens
11-
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic, DeepSeek
11+
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic
1212
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
1313

14-
from ..helpers import models_tokens
15-
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic, DeepSeek
16-
17-
1814
class AbstractGraph(ABC):
1915
"""
2016
Scaffolding class for creating a graph representation and executing it.
2117
2218
prompt (str): The prompt for the graph.
2319
source (str): The source of the graph.
2420
config (dict): Configuration parameters for the graph.
25-
schema (str): The schema for the graph output.
2621
llm_model: An instance of a language model client, configured for generating answers.
2722
embedder_model: An instance of an embedding model client,
2823
configured for generating embeddings.
@@ -33,7 +28,6 @@ class AbstractGraph(ABC):
3328
prompt (str): The prompt for the graph.
3429
config (dict): Configuration parameters for the graph.
3530
source (str, optional): The source of the graph.
36-
schema (str, optional): The schema for the graph output.
3731
3832
Example:
3933
>>> class MyGraph(AbstractGraph):
@@ -45,42 +39,34 @@ class AbstractGraph(ABC):
4539
>>> result = my_graph.run()
4640
"""
4741

48-
def __init__(self, prompt: str, config: dict, source: Optional[str] = None, schema: Optional[str] = None):
42+
def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
4943

5044
self.prompt = prompt
5145
self.source = source
5246
self.config = config
53-
self.schema = schema
5447
self.llm_model = self._create_llm(config["llm"], chat=True)
5548
self.embedder_model = self._create_default_embedder(llm_config=config["llm"]
5649
) if "embeddings" not in config else self._create_embedder(
5750
config["embeddings"])
58-
self.verbose = False if config is None else config.get(
59-
"verbose", False)
60-
self.headless = True if config is None else config.get(
61-
"headless", True)
62-
self.loader_kwargs = config.get("loader_kwargs", {})
6351

6452
# Create the graph
6553
self.graph = self._create_graph()
6654
self.final_state = None
6755
self.execution_info = None
6856

6957
# Set common configuration parameters
58+
7059
self.verbose = False if config is None else config.get(
7160
"verbose", False)
7261
self.headless = True if config is None else config.get(
7362
"headless", True)
7463
self.loader_kwargs = config.get("loader_kwargs", {})
7564

76-
common_params = {
77-
"headless": self.headless,
78-
"verbose": self.verbose,
79-
"loader_kwargs": self.loader_kwargs,
80-
"llm_model": self.llm_model,
81-
"embedder_model": self.embedder_model
82-
}
83-
65+
common_params = {"headless": self.headless,
66+
67+
"loader_kwargs": self.loader_kwargs,
68+
"llm_model": self.llm_model,
69+
"embedder_model": self.embedder_model}
8470
self.set_common_params(common_params, overwrite=False)
8571

8672
def set_common_params(self, params: dict, overwrite=False):
@@ -93,7 +79,7 @@ def set_common_params(self, params: dict, overwrite=False):
9379

9480
for node in self.graph.nodes:
9581
node.update_config(params, overwrite)
96-
82+
9783
def _set_model_token(self, llm):
9884

9985
if 'Azure' in str(type(llm)):
@@ -171,7 +157,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
171157
raise KeyError("Model not supported") from exc
172158
return Anthropic(llm_params)
173159
elif "ollama" in llm_params["model"]:
174-
llm_params["model"] = llm_params["model"].split("ollama/")[-1]
160+
llm_params["model"] = llm_params["model"].split("/")[-1]
175161

176162
# allow user to set model_tokens in config
177163
try:
@@ -245,8 +231,6 @@ def _create_default_embedder(self, llm_config=None) -> object:
245231
model="models/embedding-001")
246232
if isinstance(self.llm_model, OpenAI):
247233
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
248-
elif isinstance(self.llm_model, DeepSeek):
249-
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
250234
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
251235
return self.llm_model
252236
elif isinstance(self.llm_model, AzureOpenAI):
@@ -282,31 +266,30 @@ def _create_embedder(self, embedder_config: dict) -> object:
282266
if 'model_instance' in embedder_config:
283267
return embedder_config['model_instance']
284268
# Instantiate the embedding model based on the model name
285-
if "openai" in embedder_config["model"].split("/")[0]:
269+
if "openai" in embedder_config["model"]:
286270
return OpenAIEmbeddings(api_key=embedder_config["api_key"])
287271
elif "azure" in embedder_config["model"]:
288272
return AzureOpenAIEmbeddings()
289-
elif "ollama" in embedder_config["model"].split("/")[0]:
290-
print("ciao")
291-
embedder_config["model"] = embedder_config["model"].split("ollama/")[-1]
273+
elif "ollama" in embedder_config["model"]:
274+
embedder_config["model"] = embedder_config["model"].split("/")[-1]
292275
try:
293276
models_tokens["ollama"][embedder_config["model"]]
294277
except KeyError as exc:
295278
raise KeyError("Model not supported") from exc
296279
return OllamaEmbeddings(**embedder_config)
297-
elif "hugging_face" in embedder_config["model"].split("/")[0]:
280+
elif "hugging_face" in embedder_config["model"]:
298281
try:
299282
models_tokens["hugging_face"][embedder_config["model"]]
300283
except KeyError as exc:
301284
raise KeyError("Model not supported")from exc
302285
return HuggingFaceHubEmbeddings(model=embedder_config["model"])
303-
elif "gemini" in embedder_config["model"].split("/")[0]:
286+
elif "gemini" in embedder_config["model"]:
304287
try:
305288
models_tokens["gemini"][embedder_config["model"]]
306289
except KeyError as exc:
307290
raise KeyError("Model not supported")from exc
308291
return GoogleGenerativeAIEmbeddings(model=embedder_config["model"])
309-
elif "bedrock" in embedder_config["model"].split("/")[0]:
292+
elif "bedrock" in embedder_config["model"]:
310293
embedder_config["model"] = embedder_config["model"].split("/")[-1]
311294
client = embedder_config.get('client', None)
312295
try:

scrapegraphai/graphs/pdf_scraper_graph.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
FetchNode,
1212
ParseNode,
1313
RAGNode,
14-
GenerateAnswerNode
14+
GenerateAnswerPDFNode
1515
)
1616

1717

@@ -48,7 +48,7 @@ class PDFScraperGraph(AbstractGraph):
4848
"""
4949

5050
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
51-
super().__init__(prompt, config, source, schema)
51+
super().__init__(prompt, config, source)
5252

5353
self.input_key = "pdf" if source.endswith("pdf") else "pdf_dir"
5454

@@ -64,41 +64,21 @@ def _create_graph(self) -> BaseGraph:
6464
input='pdf | pdf_dir',
6565
output=["doc", "link_urls", "img_urls"],
6666
)
67-
parse_node = ParseNode(
68-
input="doc",
69-
output=["parsed_doc"],
70-
node_config={
71-
"chunk_size": self.model_token,
72-
}
73-
)
74-
rag_node = RAGNode(
75-
input="user_prompt & (parsed_doc | doc)",
76-
output=["relevant_chunks"],
77-
node_config={
78-
"llm_model": self.llm_model,
79-
"embedder_model": self.embedder_model,
80-
}
81-
)
82-
generate_answer_node = GenerateAnswerNode(
67+
generate_answer_node_pdf = GenerateAnswerPDFNode(
8368
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
8469
output=["answer"],
8570
node_config={
8671
"llm_model": self.llm_model,
87-
"schema": self.schema,
8872
}
8973
)
9074

9175
return BaseGraph(
9276
nodes=[
9377
fetch_node,
94-
parse_node,
95-
rag_node,
96-
generate_answer_node,
78+
generate_answer_node_pdf,
9779
],
9880
edges=[
99-
(fetch_node, parse_node),
100-
(parse_node, rag_node),
101-
(rag_node, generate_answer_node)
81+
(fetch_node, generate_answer_node_pdf)
10282
],
10383
entry_point=fetch_node
10484
)

scrapegraphai/nodes/generate_answer_pdf_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(self, input: str, output: List[str], node_config: Optional[dict] =
4949
node_name (str): name of the node
5050
"""
5151
super().__init__(node_name, "node", input, output, 2, node_config)
52-
self.llm_model = node_config["llm"]
52+
self.llm_model = node_config["llm_model"]
5353
self.verbose = False if node_config is None else node_config.get(
5454
"verbose", False)
5555

0 commit comments

Comments (0)