|
1 | 1 | """ |
2 | 2 | GenerateAnswerNode Module |
3 | 3 | """ |
4 | | - |
| 4 | +import asyncio |
5 | 5 | from typing import List, Optional |
6 | 6 | from langchain.prompts import PromptTemplate |
7 | 7 | from langchain_core.output_parsers import JsonOutputParser |
8 | 8 | from langchain_core.runnables import RunnableParallel |
9 | 9 | from tqdm import tqdm |
10 | | -import asyncio |
11 | 10 | from ..utils.merge_results import merge_results |
12 | 11 | from ..utils.logging import get_logger |
13 | 12 | from ..models import Ollama, OpenAI |
@@ -136,21 +135,18 @@ def execute(self, state: dict) -> dict: |
136 | 135 | chain_name = f"chunk{i+1}" |
137 | 136 | chains_dict[chain_name] = prompt | self.llm_model | output_parser |
138 | 137 |
|
| 138 | + async_runner = RunnableParallel(**chains_dict) |
139 | 139 |
|
140 | | - async def process_chains(): |
141 | | - async_runner = RunnableParallel() |
142 | | - for chain_name, chain in chains_dict.items(): |
143 | | - async_runner.add(chain.ainvoke([{"question": user_prompt}] * len(doc))) |
144 | | - |
145 | | - batch_results = await async_runner.run() |
146 | | - return batch_results |
| 140 | + batch_results = async_runner.invoke({"question": user_prompt}) |
147 | 141 |
|
148 | | - loop = asyncio.get_event_loop() |
149 | | - batch_answers = loop.run_until_complete(process_chains()) |
| 142 | + merge_prompt = PromptTemplate( |
| 143 | + template = template_merge_prompt, |
| 144 | + input_variables=["context", "question"], |
| 145 | + partial_variables={"format_instructions": format_instructions}, |
| 146 | + ) |
150 | 147 |
|
151 | | - # Merge batch results (assuming same structure) |
152 | | - merged_answer = merge_results(batch_answers) |
153 | | - answers = merged_answer |
| 148 | + merge_chain = merge_prompt | self.llm_model | output_parser |
| 149 | + answer = merge_chain.invoke({"context": batch_results, "question": user_prompt}) |
154 | 150 |
|
155 | | - state.update({self.output[0]: answers}) |
| 151 | + state.update({self.output[0]: answer}) |
156 | 152 | return state |
0 commit comments