Commit 69880b6

Update generate_answer_node.py
1 parent 65b8675 commit 69880b6

1 file changed

scrapegraphai/nodes/generate_answer_node.py

Lines changed: 53 additions & 90 deletions
@@ -1,6 +1,3 @@
-"""
-GenerateAnswerNode Module
-"""
 from typing import List, Optional
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
@@ -12,29 +9,12 @@
 from tqdm import tqdm
 from .base_node import BaseNode
 from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser
-from ..prompts import (TEMPLATE_CHUNKS,
-                       TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
-                       TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD,
-                       TEMPLATE_MERGE_MD)
+from ..prompts import (
+    TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
+    TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD
+)

 class GenerateAnswerNode(BaseNode):
-    """
-    A node that generates an answer using a large language model (LLM) based on the user's input
-    and the content extracted from a webpage. It constructs a prompt from the user's input
-    and the scraped content, feeds it to the LLM, and parses the LLM's response to produce
-    an answer.
-
-    Attributes:
-        llm_model: An instance of a language model client, configured for generating answers.
-        verbose (bool): A flag indicating whether to show print statements during execution.
-
-    Args:
-        input (str): Boolean expression defining the input keys needed from the state.
-        output (List[str]): List of output keys to be updated in the state.
-        node_config (dict): Additional configuration for the node.
-        node_name (str): The unique identifier name for the node, defaulting to "GenerateAnswer".
-    """
-
     def __init__(
         self,
         input: str,
@@ -43,119 +23,102 @@ def __init__(
         node_name: str = "GenerateAnswer",
     ):
         super().__init__(node_name, "node", input, output, 2, node_config)
-
         self.llm_model = node_config["llm_model"]

         if isinstance(node_config["llm_model"], ChatOllama):
-            self.llm_model.format="json"
-
-        self.verbose = (
-            True if node_config is None else node_config.get("verbose", False)
-        )
-        self.force = (
-            False if node_config is None else node_config.get("force", False)
-        )
-        self.script_creator = (
-            False if node_config is None else node_config.get("script_creator", False)
-        )
-        self.is_md_scraper = (
-            False if node_config is None else node_config.get("is_md_scraper", False)
-        )
+            self.llm_model.format = "json"

+        self.verbose = node_config.get("verbose", False)
+        self.force = node_config.get("force", False)
+        self.script_creator = node_config.get("script_creator", False)
+        self.is_md_scraper = node_config.get("is_md_scraper", False)
         self.additional_info = node_config.get("additional_info")

     def execute(self, state: dict) -> dict:
-        """
-        Generates an answer by constructing a prompt from the user's input and the scraped
-        content, querying the language model, and parsing its response.
-
-        Args:
-            state (dict): The current state of the graph. The input keys will be used
-                          to fetch the correct data from the state.
-
-        Returns:
-            dict: The updated state with the output key containing the generated answer.
-
-        Raises:
-            KeyError: If the input keys are not found in the state, indicating
-                      that the necessary information for generating an answer is missing.
-        """
-
         self.logger.info(f"--- Executing {self.node_name} Node ---")

-        input_keys = self.get_input_keys(state)
+        input_keys = self.get_input_keys(state)
         input_data = [state[key] for key in input_keys]
         user_prompt = input_data[0]
         doc = input_data[1]

         if self.node_config.get("schema", None) is not None:
-
             if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
                 self.llm_model = self.llm_model.with_structured_output(
-                    schema = self.node_config["schema"])
+                    schema=self.node_config["schema"]
+                )
                 output_parser = get_structured_output_parser(self.node_config["schema"])
                 format_instructions = "NA"
             else:
-                output_parser = get_pydantic_output_parser(self.node_config["schema"])
-                format_instructions = output_parser.get_format_instructions()
-
+                if not isinstance(self.llm_model, ChatBedrock):
+                    output_parser = get_pydantic_output_parser(self.node_config["schema"])
+                    format_instructions = output_parser.get_format_instructions()
+                else:
+                    output_parser = None
+                    format_instructions = ""
         else:
-            output_parser = JsonOutputParser()
-            format_instructions = output_parser.get_format_instructions()
+            if not isinstance(self.llm_model, ChatBedrock):
+                output_parser = JsonOutputParser()
+                format_instructions = output_parser.get_format_instructions()
+            else:
+                output_parser = None
+                format_instructions = ""

         if isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) \
             and not self.script_creator \
             or self.force \
             and not self.script_creator or self.is_md_scraper:
-
-            template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
-            template_chunks_prompt = TEMPLATE_CHUNKS_MD
-            template_merge_prompt = TEMPLATE_MERGE_MD
+            template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
+            template_chunks_prompt = TEMPLATE_CHUNKS_MD
+            template_merge_prompt = TEMPLATE_MERGE_MD
         else:
-            template_no_chunks_prompt = TEMPLATE_NO_CHUNKS
-            template_chunks_prompt = TEMPLATE_CHUNKS
-            template_merge_prompt = TEMPLATE_MERGE
+            template_no_chunks_prompt = TEMPLATE_NO_CHUNKS
+            template_chunks_prompt = TEMPLATE_CHUNKS
+            template_merge_prompt = TEMPLATE_MERGE

         if self.additional_info is not None:
-            template_no_chunks_prompt = self.additional_info + template_no_chunks_prompt
-            template_chunks_prompt = self.additional_info + template_chunks_prompt
-            template_merge_prompt = self.additional_info + template_merge_prompt
+            template_no_chunks_prompt = self.additional_info + template_no_chunks_prompt
+            template_chunks_prompt = self.additional_info + template_chunks_prompt
+            template_merge_prompt = self.additional_info + template_merge_prompt

         if len(doc) == 1:
             prompt = PromptTemplate(
-                template=template_no_chunks_prompt ,
+                template=template_no_chunks_prompt,
                 input_variables=["question"],
-                partial_variables={"context": doc,
-                                   "format_instructions": format_instructions})
-            chain = prompt | self.llm_model | output_parser
+                partial_variables={"context": doc, "format_instructions": format_instructions}
+            )
+            chain = prompt | self.llm_model
+            if output_parser:
+                chain = chain | output_parser
             answer = chain.invoke({"question": user_prompt})

             state.update({self.output[0]: answer})
             return state

         chains_dict = {}
         for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
-
             prompt = PromptTemplate(
-                template=TEMPLATE_CHUNKS,
+                template=template_chunks_prompt,
                 input_variables=["question"],
-                partial_variables={"context": chunk,
-                                   "chunk_id": i + 1,
-                                   "format_instructions": format_instructions})
+                partial_variables={"context": chunk, "chunk_id": i + 1, "format_instructions": format_instructions}
+            )
             chain_name = f"chunk{i+1}"
-            chains_dict[chain_name] = prompt | self.llm_model | output_parser
+            chains_dict[chain_name] = prompt | self.llm_model
+            if output_parser:
+                chains_dict[chain_name] = chains_dict[chain_name] | output_parser

         async_runner = RunnableParallel(**chains_dict)
-
-        batch_results = async_runner.invoke({"question": user_prompt})
+        batch_results = async_runner.invoke({"question": user_prompt})

         merge_prompt = PromptTemplate(
-            template = template_merge_prompt ,
-            input_variables=["context", "question"],
-            partial_variables={"format_instructions": format_instructions},
-        )
+            template=template_merge_prompt,
+            input_variables=["context", "question"],
+            partial_variables={"format_instructions": format_instructions}
+        )

-        merge_chain = merge_prompt | self.llm_model | output_parser
+        merge_chain = merge_prompt | self.llm_model
+        if output_parser:
+            merge_chain = merge_chain | output_parser
         answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})

         state.update({self.output[0]: answer})
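
The main behavioral change above is that the output parser is now optional: when the model is a ChatBedrock instance, no parser is built and format_instructions stays empty, so the parser can no longer be piped onto the chain unconditionally. Below is a minimal offline sketch of that pattern; build_chain, fake_llm, and the toy template are illustrative stand-ins rather than part of the node, and langchain_core is used directly so the snippet runs without model credentials.

from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda

def build_chain(llm, output_parser=None):
    # Mirror the commit's pattern: only append the parser when one exists
    # (output_parser is None for ChatBedrock in the updated node).
    prompt = PromptTemplate(
        template="{format_instructions}\nQuestion: {question}",
        input_variables=["question"],
        partial_variables={
            "format_instructions": (
                output_parser.get_format_instructions() if output_parser else ""
            )
        },
    )
    chain = prompt | llm
    if output_parser:
        chain = chain | output_parser
    return chain

# Stand-in "LLM" that always returns a JSON string, so the sketch runs offline.
fake_llm = RunnableLambda(lambda _: '{"answer": "42"}')

print(build_chain(fake_llm, JsonOutputParser()).invoke({"question": "What is 6 x 7?"}))
print(build_chain(fake_llm).invoke({"question": "What is 6 x 7?"}))  # no parser: raw output

With a parser the first call returns a dict ({'answer': '42'}); without one the raw model output passes through untouched, which is the behavior the node now relies on when output_parser is None.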

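For documents split into several chunks, the node keeps its map-then-merge flow: one chain per chunk is registered in chains_dict, RunnableParallel runs them together, and a merge chain combines the partial answers. A hedged, self-contained sketch of that flow follows; the echo-style fake_llm, the toy templates, and the sample doc list are assumptions made so the example runs offline and are not the node's real prompts.

from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableParallel

# Echo "model": returns a short tag of whatever prompt it receives.
fake_llm = RunnableLambda(lambda pv: f"summary of: {pv.to_string()[:40]}")

doc = ["chunk one text", "chunk two text", "chunk three text"]
question = "What is this document about?"

# Map step: one chain per chunk, keyed so the partial results can be told apart.
chains_dict = {}
for i, chunk in enumerate(doc):
    prompt = PromptTemplate(
        template="Chunk {chunk_id}: {context}\nQuestion: {question}",
        input_variables=["question"],
        partial_variables={"context": chunk, "chunk_id": i + 1},
    )
    chains_dict[f"chunk{i+1}"] = prompt | fake_llm

# Run all per-chunk chains concurrently.
batch_results = RunnableParallel(**chains_dict).invoke({"question": question})

# Merge step: feed the partial answers back through a single merge prompt.
merge_prompt = PromptTemplate(
    template="Combine these partial answers: {context}\nQuestion: {question}",
    input_variables=["context", "question"],
)
merge_chain = merge_prompt | fake_llm
print(merge_chain.invoke({"context": batch_results, "question": question}))

As in the single-chunk path, the commit's change means a real output parser would only be piped onto chains_dict[...] and merge_chain when output_parser is not None.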