|
1 | 1 | """ |
2 | 2 | GenerateAnswerNode Module |
3 | 3 | """ |
4 | | - |
| 4 | +import asyncio |
5 | 5 | from typing import List, Optional |
6 | 6 | from langchain.prompts import PromptTemplate |
7 | 7 | from langchain_core.output_parsers import JsonOutputParser |
@@ -107,44 +107,43 @@ def execute(self, state: dict) -> dict: |
107 | 107 | template_chunks_prompt = self.additional_info + template_chunks_prompt |
108 | 108 | template_merge_prompt = self.additional_info + template_merge_prompt |
109 | 109 |
|
110 | | - chains_dict = {} |
| 110 | + if len(doc) == 1: |
| 111 | + prompt = PromptTemplate( |
| 112 | + template=template_no_chunks_prompt, |
| 113 | + input_variables=["question"], |
| 114 | + partial_variables={"context": doc, |
| 115 | + "format_instructions": format_instructions}) |
| 116 | + chain = prompt | self.llm_model | output_parser |
| 117 | + answer = chain.invoke({"question": user_prompt}) |
| 118 | + |
| 119 | + state.update({self.output[0]: answer}) |
| 120 | + return state |
111 | 121 |
|
112 | | - # Use tqdm to add progress bar |
| 122 | + chains_dict = {} |
113 | 123 | for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)): |
114 | | - if len(doc) == 1: |
115 | | - prompt = PromptTemplate( |
116 | | - template=template_no_chunks_prompt, |
117 | | - input_variables=["question"], |
118 | | - partial_variables={"context": chunk, |
119 | | - "format_instructions": format_instructions}) |
120 | | - chain = prompt | self.llm_model | output_parser |
121 | | - answer = chain.invoke({"question": user_prompt}) |
122 | | - break |
123 | 124 |
|
124 | 125 | prompt = PromptTemplate( |
125 | | - template=template_chunks_prompt, |
126 | | - input_variables=["question"], |
127 | | - partial_variables={"context": chunk, |
128 | | - "chunk_id": i + 1, |
129 | | - "format_instructions": format_instructions}) |
130 | | - # Dynamically name the chains based on their index |
| 126 | + template=template_chunks, |
| 127 | + input_variables=["question"], |
| 128 | + partial_variables={"context": chunk, |
| 129 | + "chunk_id": i + 1, |
| 130 | + "format_instructions": format_instructions}) |
| 131 | + # Add chain to dictionary with dynamic name |
131 | 132 | chain_name = f"chunk{i+1}" |
132 | 133 | chains_dict[chain_name] = prompt | self.llm_model | output_parser |
133 | 134 |
|
134 | | - if len(chains_dict) > 1: |
135 | | - # Use dictionary unpacking to pass the dynamically named chains to RunnableParallel |
136 | | - map_chain = RunnableParallel(**chains_dict) |
137 | | - # Chain |
138 | | - answer = map_chain.invoke({"question": user_prompt}) |
139 | | - # Merge the answers from the chunks |
140 | | - merge_prompt = PromptTemplate( |
| 135 | + async_runner = RunnableParallel(**chains_dict) |
| 136 | + |
| 137 | + batch_results = async_runner.invoke({"question": user_prompt}) |
| 138 | + |
| 139 | + merge_prompt = PromptTemplate( |
141 | 140 | template = template_merge_prompt, |
142 | 141 | input_variables=["context", "question"], |
143 | 142 | partial_variables={"format_instructions": format_instructions}, |
144 | 143 | ) |
145 | | - merge_chain = merge_prompt | self.llm_model | output_parser |
146 | | - answer = merge_chain.invoke({"context": answer, "question": user_prompt}) |
147 | 144 |
|
148 | | - # Update the state with the generated answer |
| 145 | + merge_chain = merge_prompt | self.llm_model | output_parser |
| 146 | + answer = merge_chain.invoke({"context": batch_results, "question": user_prompt}) |
| 147 | + |
149 | 148 | state.update({self.output[0]: answer}) |
150 | 149 | return state |
0 commit comments