Skip to content

Commit 0c4b290

Browse files
VinciGit00 and DiTo97
committed
feat: add generate_answer node parallelization
Co-Authored-By: Federico Minutoli <[email protected]>
1 parent 2ae19ae commit 0c4b290

File tree

1 file changed

+11
-15
lines changed

1 file changed

+11
-15
lines changed

scrapegraphai/nodes/generate_answer_node.py

Lines changed: 11 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,12 @@
11
"""
22
GenerateAnswerNode Module
33
"""
4-
4+
import asyncio
55
from typing import List, Optional
66
from langchain.prompts import PromptTemplate
77
from langchain_core.output_parsers import JsonOutputParser
88
from langchain_core.runnables import RunnableParallel
99
from tqdm import tqdm
10-
import asyncio
1110
from ..utils.merge_results import merge_results
1211
from ..utils.logging import get_logger
1312
from ..models import Ollama, OpenAI
@@ -136,21 +135,18 @@ def execute(self, state: dict) -> dict:
136135
chain_name = f"chunk{i+1}"
137136
chains_dict[chain_name] = prompt | self.llm_model | output_parser
138137

138+
async_runner = RunnableParallel(**chains_dict)
139139

140-
async def process_chains():
141-
async_runner = RunnableParallel()
142-
for chain_name, chain in chains_dict.items():
143-
async_runner.add(chain.ainvoke([{"question": user_prompt}] * len(doc)))
144-
145-
batch_results = await async_runner.run()
146-
return batch_results
140+
batch_results = async_runner.invoke({"question": user_prompt})
147141

148-
loop = asyncio.get_event_loop()
149-
batch_answers = loop.run_until_complete(process_chains())
142+
merge_prompt = PromptTemplate(
143+
template = template_merge_prompt,
144+
input_variables=["context", "question"],
145+
partial_variables={"format_instructions": format_instructions},
146+
)
150147

151-
# Merge batch results (assuming same structure)
152-
merged_answer = merge_results(batch_answers)
153-
answers = merged_answer
148+
merge_chain = merge_prompt | self.llm_model | output_parser
149+
answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
154150

155-
state.update({self.output[0]: answers})
151+
state.update({self.output[0]: answer})
156152
return state

0 commit comments

Comments (0)