
Commit b2d170c

refactored and fixed single chunk bug
1 parent 7b9a49c commit b2d170c


2 files changed: +23 -30 lines changed

scrapegraphai/graphs/script_creator_graph.py

Lines changed: 3 additions & 3 deletions
@@ -46,7 +46,7 @@ def _create_graph(self):
                 "embedder_model": self.embedder_model
             }
         )
-        generate_answer_node = GenerateScraperNode(
+        generate_scraper_node = GenerateScraperNode(
             input="user_prompt & (relevant_chunks | parsed_doc | doc)",
             output=["answer"],
             node_config={"llm": self.llm_model},
@@ -57,12 +57,12 @@ def _create_graph(self):
                 fetch_node,
                 parse_node,
                 rag_node,
-                generate_answer_node,
+                generate_scraper_node,
             },
             edges={
                 (fetch_node, parse_node),
                 (parse_node, rag_node),
-                (rag_node, generate_answer_node)
+                (rag_node, generate_scraper_node)
             },
             entry_point=fetch_node
         )

scrapegraphai/nodes/generate_scraper_node.py

Lines changed: 20 additions & 27 deletions
@@ -121,31 +121,28 @@ def execute(self, state):
 
         # Use tqdm to add progress bar
         for i, chunk in enumerate(tqdm(doc, desc="Processing chunks")):
-            if len(doc) == 1:
-                prompt = PromptTemplate(
-                    template=template_no_chunks,
-                    input_variables=["question"],
-                    partial_variables={"context": chunk.page_content,
-                                       "chunk_id": i + 1,
-                                       "format_instructions": format_instructions},
-                )
+            if len(doc) > 1:
+                template = template_chunks
             else:
-                prompt = PromptTemplate(
-                    template=template_chunks,
-                    input_variables=["question"],
-                    partial_variables={"context": chunk.page_content,
-                                       "chunk_id": i + 1,
-                                       "format_instructions": format_instructions},
-                )
+                template = template_no_chunks
+
+            prompt = PromptTemplate(
+                template=template,
+                input_variables=["question"],
+                partial_variables={"context": chunk.page_content,
+                                   "chunk_id": i + 1,
+                                   "format_instructions": format_instructions},
+            )
             # Dynamically name the chains based on their index
             chain_name = f"chunk{i+1}"
             chains_dict[chain_name] = prompt | self.llm_model | output_parser
 
+        # Use dictionary unpacking to pass the dynamically named chains to RunnableParallel
+        map_chain = RunnableParallel(**chains_dict)
+        # Chain
+        answer = map_chain.invoke({"question": user_prompt})
+
         if len(chains_dict) > 1:
-            # Use dictionary unpacking to pass the dynamically named chains to RunnableParallel
-            map_chain = RunnableParallel(**chains_dict)
-            # Chain
-            answer_map = map_chain.invoke({"question": user_prompt})
 
             # Merge the answers from the chunks
             merge_prompt = PromptTemplate(
@@ -155,11 +152,7 @@ def execute(self, state):
             )
             merge_chain = merge_prompt | self.llm_model | output_parser
             answer = merge_chain.invoke(
-                {"context": answer_map, "question": user_prompt})
-
-            # Update the state with the generated answer
-            state.update({self.output[0]: answer})
-            return state
-        else:
-            state.update({self.output[0]: chains_dict})
-            return state
+                {"context": answer, "question": user_prompt})
+
+        state.update({self.output[0]: answer})
+        return state
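
For context on the fix, the following is a minimal, self-contained sketch of the refactored control flow, not the node's actual implementation. It assumes langchain_core is installed; the fake llm, output_parser, the templates, and the generate() helper are illustrative stand-ins. The point it demonstrates: the parallel map is now always invoked, so a single-chunk document produces an answer (the one-entry dict from RunnableParallel) instead of the raw chains_dict, and the merge chain only runs when there is more than one chunk.

# Minimal sketch of the fixed flow; llm, output_parser, the templates and
# generate() are hypothetical stand-ins so the example runs without a model.
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableParallel

llm = RunnableLambda(lambda p: f"LLM answer to: {p.to_string()[:60]}")  # fake LLM
output_parser = RunnableLambda(lambda text: text.strip())               # fake parser

template_no_chunks = "Answer '{question}' from this page:\n{context}\n{format_instructions}"
template_chunks = "Answer '{question}' from chunk {chunk_id}:\n{context}\n{format_instructions}"
merge_template = "Merge these partial answers into one: {context}\nQuestion: {question}"
format_instructions = "Return plain text."


def generate(doc, user_prompt):
    chains_dict = {}
    for i, chunk in enumerate(doc):
        # Pick the chunked template only when the document has more than one chunk
        template = template_chunks if len(doc) > 1 else template_no_chunks
        prompt = PromptTemplate(
            template=template,
            input_variables=["question"],
            partial_variables={"context": chunk,
                               "chunk_id": i + 1,
                               "format_instructions": format_instructions},
        )
        chains_dict[f"chunk{i + 1}"] = prompt | llm | output_parser

    # Always run the parallel map: with one chunk this still yields an answer
    # dict ({"chunk1": ...}) rather than leaving the raw chains_dict in state.
    map_chain = RunnableParallel(**chains_dict)
    answer = map_chain.invoke({"question": user_prompt})

    # Merge the per-chunk answers only when there are several of them
    if len(chains_dict) > 1:
        merge_prompt = PromptTemplate(
            template=merge_template,
            input_variables=["context", "question"],
        )
        merge_chain = merge_prompt | llm | output_parser
        answer = merge_chain.invoke({"context": answer, "question": user_prompt})
    return answer


print(generate(["only one chunk of page text"], "extract the page title"))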
