Skip to content

Commit 8fc4187

Browse files
committed
refactoring of generate answer node
1 parent 4c74e01 commit 8fc4187

File tree

1 file changed

+24
-18
lines changed

1 file changed

+24
-18
lines changed

scrapegraphai/nodes/generate_answer_node.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -113,21 +113,27 @@ def execute(self, state):
113113
chain_name = f"chunk{i+1}"
114114
chains_dict[chain_name] = prompt | self.llm_model | output_parser
115115

116-
# Use dictionary unpacking to pass the dynamically named chains to RunnableParallel
117-
map_chain = RunnableParallel(**chains_dict)
118-
# Chain
119-
answer_map = map_chain.invoke({"question": user_prompt})
120-
121-
# Merge the answers from the chunks
122-
merge_prompt = PromptTemplate(
123-
template=template_merge,
124-
input_variables=["context", "question"],
125-
partial_variables={"format_instructions": format_instructions},
126-
)
127-
merge_chain = merge_prompt | self.llm_model | output_parser
128-
answer = merge_chain.invoke(
129-
{"context": answer_map, "question": user_prompt})
130-
131-
# Update the state with the generated answer
132-
state.update({self.output[0]: answer})
133-
return state
116+
if len(chains_dict) > 1:
117+
# Use dictionary unpacking to pass the dynamically named chains to RunnableParallel
118+
map_chain = RunnableParallel(**chains_dict)
119+
# Chain
120+
answer_map = map_chain.invoke({"question": user_prompt})
121+
122+
# Merge the answers from the chunks
123+
merge_prompt = PromptTemplate(
124+
template=template_merge,
125+
input_variables=["context", "question"],
126+
partial_variables={"format_instructions": format_instructions},
127+
)
128+
merge_chain = merge_prompt | self.llm_model | output_parser
129+
answer = merge_chain.invoke(
130+
{"context": answer_map, "question": user_prompt})
131+
132+
# Update the state with the generated answer
133+
state.update({self.output[0]: answer})
134+
return state
135+
136+
else:
137+
# Update the state with the generated answer
138+
state.update({self.output[0]: chains_dict})
139+
return state

0 commit comments

Comments
 (0)