@@ -121,31 +121,28 @@ def execute(self, state):
121121
122122 # Use tqdm to add progress bar
123123 for i , chunk in enumerate (tqdm (doc , desc = "Processing chunks" )):
124- if len (doc ) == 1 :
125- prompt = PromptTemplate (
126- template = template_no_chunks ,
127- input_variables = ["question" ],
128- partial_variables = {"context" : chunk .page_content ,
129- "chunk_id" : i + 1 ,
130- "format_instructions" : format_instructions },
131- )
124+ if len (doc ) > 1 :
125+ template = template_chunks
132126 else :
133- prompt = PromptTemplate (
134- template = template_chunks ,
135- input_variables = ["question" ],
136- partial_variables = {"context" : chunk .page_content ,
137- "chunk_id" : i + 1 ,
138- "format_instructions" : format_instructions },
139- )
127+ template = template_no_chunks
128+
129+ prompt = PromptTemplate (
130+ template = template ,
131+ input_variables = ["question" ],
132+ partial_variables = {"context" : chunk .page_content ,
133+ "chunk_id" : i + 1 ,
134+ "format_instructions" : format_instructions },
135+ )
140136 # Dynamically name the chains based on their index
141137 chain_name = f"chunk{ i + 1 } "
142138 chains_dict [chain_name ] = prompt | self .llm_model | output_parser
143139
140+ # Use dictionary unpacking to pass the dynamically named chains to RunnableParallel
141+ map_chain = RunnableParallel (** chains_dict )
142+ # Chain
143+ answer = map_chain .invoke ({"question" : user_prompt })
144+
144145 if len (chains_dict ) > 1 :
145- # Use dictionary unpacking to pass the dynamically named chains to RunnableParallel
146- map_chain = RunnableParallel (** chains_dict )
147- # Chain
148- answer_map = map_chain .invoke ({"question" : user_prompt })
149146
150147 # Merge the answers from the chunks
151148 merge_prompt = PromptTemplate (
@@ -155,11 +152,7 @@ def execute(self, state):
155152 )
156153 merge_chain = merge_prompt | self .llm_model | output_parser
157154 answer = merge_chain .invoke (
158- {"context" : answer_map , "question" : user_prompt })
159-
160- # Update the state with the generated answer
161- state .update ({self .output [0 ]: answer })
162- return state
163- else :
164- state .update ({self .output [0 ]: chains_dict })
165- return state
155+ {"context" : answer , "question" : user_prompt })
156+
157+ state .update ({self .output [0 ]: answer })
158+ return state
0 commit comments