Commit 390ad82

Merge pull request #689 from ScrapeGraphAI/687-smartscrapermulticoncatgraph-error-with-bedrock
687: SmartScraperMultiConcatGraph error with Bedrock
2 parents dd0f260 + 8ce08ba commit 390ad82

4 files changed: +55 -97 lines changed

scrapegraphai/graphs/smart_scraper_multi_concat_graph.py

Lines changed: 4 additions & 9 deletions

@@ -60,19 +60,14 @@ def _create_graph(self) -> BaseGraph:
             BaseGraph: A graph instance representing the web scraping and searching workflow.
         """
 
-        smart_scraper_instance = SmartScraperGraph(
-            prompt="",
-            source="",
-            config=self.copy_config,
-            schema=self.copy_schema
-        )
-
         graph_iterator_node = GraphIteratorNode(
             input="user_prompt & urls",
             output=["results"],
             node_config={
-                "graph_instance": smart_scraper_instance,
-            }
+                "graph_instance": SmartScraperGraph,
+                "scraper_config": self.copy_config,
+            },
+            schema=self.copy_schema,
         )
 
         concat_answers_node = ConcatAnswersNode(
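
With this change, GraphIteratorNode receives the SmartScraperGraph class together with a separate "scraper_config" instead of one pre-built instance, so each URL can get its own freshly constructed graph at run time. A minimal sketch of the per-URL instantiation this enables, assuming a hypothetical helper name build_scrapers (the real logic lives inside GraphIteratorNode):

# Hypothetical sketch, not the library's internals: build one SmartScraperGraph
# per URL from the class and config now carried in node_config.
from scrapegraphai.graphs import SmartScraperGraph

def build_scrapers(urls, prompt, node_config, schema=None):
    graph_cls = node_config["graph_instance"]       # the class itself, e.g. SmartScraperGraph
    scraper_config = node_config["scraper_config"]  # LLM/embedder configuration
    # Creating each graph here avoids deep-copying a shared instance whose
    # model client (for example a Bedrock client) may not survive copying.
    return [
        graph_cls(prompt=prompt, source=url, config=scraper_config, schema=schema)
        for url in urls
    ]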

scrapegraphai/nodes/generate_answer_node.py

Lines changed: 47 additions & 86 deletions

@@ -1,6 +1,3 @@
-"""
-GenerateAnswerNode Module
-"""
 from typing import List, Optional
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
@@ -12,29 +9,12 @@
 from tqdm import tqdm
 from .base_node import BaseNode
 from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser
-from ..prompts import (TEMPLATE_CHUNKS,
-                       TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
-                       TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD,
-                       TEMPLATE_MERGE_MD)
+from ..prompts import (
+    TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
+    TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD
+)
 
 class GenerateAnswerNode(BaseNode):
-    """
-    A node that generates an answer using a large language model (LLM) based on the user's input
-    and the content extracted from a webpage. It constructs a prompt from the user's input
-    and the scraped content, feeds it to the LLM, and parses the LLM's response to produce
-    an answer.
-
-    Attributes:
-        llm_model: An instance of a language model client, configured for generating answers.
-        verbose (bool): A flag indicating whether to show print statements during execution.
-
-    Args:
-        input (str): Boolean expression defining the input keys needed from the state.
-        output (List[str]): List of output keys to be updated in the state.
-        node_config (dict): Additional configuration for the node.
-        node_name (str): The unique identifier name for the node, defaulting to "GenerateAnswer".
-    """
-
     def __init__(
         self,
         input: str,
@@ -43,121 +23,102 @@ def __init__(
         node_name: str = "GenerateAnswer",
     ):
         super().__init__(node_name, "node", input, output, 2, node_config)
-
         self.llm_model = node_config["llm_model"]
 
         if isinstance(node_config["llm_model"], ChatOllama):
-            self.llm_model.format="json"
-
-        self.verbose = (
-            True if node_config is None else node_config.get("verbose", False)
-        )
-        self.force = (
-            False if node_config is None else node_config.get("force", False)
-        )
-        self.script_creator = (
-            False if node_config is None else node_config.get("script_creator", False)
-        )
-        self.is_md_scraper = (
-            False if node_config is None else node_config.get("is_md_scraper", False)
-        )
+            self.llm_model.format = "json"
 
+        self.verbose = node_config.get("verbose", False)
+        self.force = node_config.get("force", False)
+        self.script_creator = node_config.get("script_creator", False)
+        self.is_md_scraper = node_config.get("is_md_scraper", False)
         self.additional_info = node_config.get("additional_info")
 
     def execute(self, state: dict) -> dict:
-        """
-        Generates an answer by constructing a prompt from the user's input and the scraped
-        content, querying the language model, and parsing its response.
-
-        Args:
-            state (dict): The current state of the graph. The input keys will be used
-                          to fetch the correct data from the state.
-
-        Returns:
-            dict: The updated state with the output key containing the generated answer.
-
-        Raises:
-            KeyError: If the input keys are not found in the state, indicating
-                      that the necessary information for generating an answer is missing.
-        """
-
         self.logger.info(f"--- Executing {self.node_name} Node ---")
 
-        input_keys = self.get_input_keys(state)
+        input_keys = self.get_input_keys(state)
         input_data = [state[key] for key in input_keys]
         user_prompt = input_data[0]
         doc = input_data[1]
 
         if self.node_config.get("schema", None) is not None:
-
             if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
                 self.llm_model = self.llm_model.with_structured_output(
-                    schema = self.node_config["schema"])
+                    schema=self.node_config["schema"]
+                )
                 output_parser = get_structured_output_parser(self.node_config["schema"])
                 format_instructions = "NA"
             else:
                 if not isinstance(self.llm_model, ChatBedrock):
                     output_parser = get_pydantic_output_parser(self.node_config["schema"])
                     format_instructions = output_parser.get_format_instructions()
-
+                else:
+                    output_parser = None
+                    format_instructions = ""
         else:
             if not isinstance(self.llm_model, ChatBedrock):
                 output_parser = JsonOutputParser()
                 format_instructions = output_parser.get_format_instructions()
+            else:
+                output_parser = None
+                format_instructions = ""
 
         if isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) \
             and not self.script_creator \
             or self.force \
             and not self.script_creator or self.is_md_scraper:
-
-            template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
-            template_chunks_prompt = TEMPLATE_CHUNKS_MD
-            template_merge_prompt = TEMPLATE_MERGE_MD
+            template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
+            template_chunks_prompt = TEMPLATE_CHUNKS_MD
+            template_merge_prompt = TEMPLATE_MERGE_MD
         else:
-            template_no_chunks_prompt = TEMPLATE_NO_CHUNKS
-            template_chunks_prompt = TEMPLATE_CHUNKS
-            template_merge_prompt = TEMPLATE_MERGE
+            template_no_chunks_prompt = TEMPLATE_NO_CHUNKS
+            template_chunks_prompt = TEMPLATE_CHUNKS
+            template_merge_prompt = TEMPLATE_MERGE
 
         if self.additional_info is not None:
-            template_no_chunks_prompt = self.additional_info + template_no_chunks_prompt
-            template_chunks_prompt = self.additional_info + template_chunks_prompt
-            template_merge_prompt = self.additional_info + template_merge_prompt
+            template_no_chunks_prompt = self.additional_info + template_no_chunks_prompt
+            template_chunks_prompt = self.additional_info + template_chunks_prompt
+            template_merge_prompt = self.additional_info + template_merge_prompt
 
         if len(doc) == 1:
             prompt = PromptTemplate(
-                template=template_no_chunks_prompt ,
+                template=template_no_chunks_prompt,
                 input_variables=["question"],
-                partial_variables={"context": doc,
-                                   "format_instructions": format_instructions})
-            chain = prompt | self.llm_model | output_parser
+                partial_variables={"context": doc, "format_instructions": format_instructions}
+            )
+            chain = prompt | self.llm_model
+            if output_parser:
+                chain = chain | output_parser
             answer = chain.invoke({"question": user_prompt})
 
             state.update({self.output[0]: answer})
             return state
 
         chains_dict = {}
         for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
-
             prompt = PromptTemplate(
-                template=TEMPLATE_CHUNKS,
+                template=template_chunks_prompt,
                 input_variables=["question"],
-                partial_variables={"context": chunk,
-                                   "chunk_id": i + 1,
-                                   "format_instructions": format_instructions})
+                partial_variables={"context": chunk, "chunk_id": i + 1, "format_instructions": format_instructions}
+            )
             chain_name = f"chunk{i+1}"
-            chains_dict[chain_name] = prompt | self.llm_model | output_parser
+            chains_dict[chain_name] = prompt | self.llm_model
+            if output_parser:
+                chains_dict[chain_name] = chains_dict[chain_name] | output_parser
 
         async_runner = RunnableParallel(**chains_dict)
-
-        batch_results = async_runner.invoke({"question": user_prompt})
+        batch_results = async_runner.invoke({"question": user_prompt})
 
         merge_prompt = PromptTemplate(
-            template = template_merge_prompt ,
-            input_variables=["context", "question"],
-            partial_variables={"format_instructions": format_instructions},
-        )
+            template=template_merge_prompt,
+            input_variables=["context", "question"],
+            partial_variables={"format_instructions": format_instructions}
+        )
 
-        merge_chain = merge_prompt | self.llm_model | output_parser
+        merge_chain = merge_prompt | self.llm_model
+        if output_parser:
+            merge_chain = merge_chain | output_parser
         answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
 
         state.update({self.output[0]: answer})
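
The pattern that makes this work with Bedrock: when the model is a ChatBedrock instance, output_parser is set to None and format_instructions to "", and every chain appends the parser only if one exists. A standalone sketch of that parser-optional composition, where llm and template are placeholders rather than ScrapeGraphAI's own objects:

# Sketch of the parser-optional chain pattern above; any LangChain chat model
# works as `llm`, and `template` must contain {question}, {context} and
# {format_instructions} placeholders.
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser

def build_chain(llm, template: str, context: str, use_parser: bool = True):
    output_parser = JsonOutputParser() if use_parser else None  # None mirrors the Bedrock path
    format_instructions = (
        output_parser.get_format_instructions() if output_parser else ""
    )
    prompt = PromptTemplate(
        template=template,
        input_variables=["question"],
        partial_variables={"context": context,
                           "format_instructions": format_instructions},
    )
    chain = prompt | llm               # raw model output when no parser is attached
    if output_parser:
        chain = chain | output_parser  # parse to JSON only when supported
    return chain

Invoking build_chain(llm, tmpl, ctx).invoke({"question": "..."}) then yields parsed JSON when a parser is attached and the raw model message otherwise, so downstream consumers have to tolerate both shapes.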

scrapegraphai/nodes/graph_iterator_node.py

Lines changed: 1 addition & 1 deletion (whitespace-only change)

@@ -130,7 +130,7 @@ async def _async_run(graph):
             if url.startswith("http"):
                 graph.input_key = "url"
             participants.append(graph)
-
+
         futures = [_async_run(graph) for graph in participants]
 
         answers = await tqdm.gather(

scrapegraphai/utils/research_web.py

Lines changed: 3 additions & 1 deletion

@@ -60,7 +60,9 @@ def search_on_web(query: str, search_engine: str = "Google",
 
     elif search_engine.lower() == "searxng":
         url = f"http://localhost:{port}"
-        params = {"q": query, "format": "json"}
+        params = {"q": query,
+                  "format": "json",
+                  "engines": "google,duckduckgo,brave,qwant,bing"}
 
         response = requests.get(url, params=params)
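
For context, a self-contained sketch of the SearxNG request with the new engines parameter; the localhost URL and the "results"/"url" fields follow SearxNG's JSON API as used above, while the function name searxng_search is illustrative:

# Illustrative sketch of the SearxNG call shown in the diff; not the
# library's actual function.
import requests

def searxng_search(query: str, port: int = 8080) -> list:
    url = f"http://localhost:{port}"
    params = {"q": query,
              "format": "json",
              "engines": "google,duckduckgo,brave,qwant,bing"}
    response = requests.get(url, params=params)
    response.raise_for_status()
    # SearxNG's JSON payload lists hits under "results", each carrying a "url".
    return [result["url"] for result in response.json().get("results", [])]

Pinning the engines explicitly should make results more consistent across SearxNG deployments that enable different default engines.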
