@@ -60,7 +60,22 @@ def __init__(
         self.script_creator = node_config.get("script_creator", False)
         self.is_md_scraper = node_config.get("is_md_scraper", False)
         self.additional_info = node_config.get("additional_info")
-        self.timeout = node_config.get("timeout", 30)
+        self.timeout = node_config.get("timeout", 120)
+
+    def invoke_with_timeout(self, chain, inputs, timeout):
+        """Helper method to invoke chain with timeout"""
+        try:
+            start_time = time.time()
+            response = chain.invoke(inputs)
+            if time.time() - start_time > timeout:
+                raise Timeout(f"Response took longer than {timeout} seconds")
+            return response
+        except Timeout as e:
+            self.logger.error(f"Timeout error: {str(e)}")
+            raise
+        except Exception as e:
+            self.logger.error(f"Error during chain execution: {str(e)}")
+            raise
 
     def execute(self, state: dict) -> dict:
         """
@@ -116,39 +131,22 @@ def execute(self, state: dict) -> dict:
             template_chunks_prompt = self.additional_info + template_chunks_prompt
             template_merge_prompt = self.additional_info + template_merge_prompt
 
-        def invoke_with_timeout(chain, inputs, timeout):
-            try:
-                with get_openai_callback() as cb:
-                    start_time = time.time()
-                    response = chain.invoke(inputs)
-                    if time.time() - start_time > timeout:
-                        raise Timeout(f"Response took longer than {timeout} seconds")
-                    return response
-            except Timeout as e:
-                self.logger.error(f"Timeout error: {str(e)}")
-                raise
-            except Exception as e:
-                self.logger.error(f"Error during chain execution: {str(e)}")
-                raise
-
         if len(doc) == 1:
             prompt = PromptTemplate(
                 template=template_no_chunks_prompt,
                 input_variables=["question"],
                 partial_variables={"context": doc, "format_instructions": format_instructions}
             )
             chain = prompt | self.llm_model
+            if output_parser:
+                chain = chain | output_parser
 
             try:
-                raw_response = invoke_with_timeout(chain, {"question": user_prompt}, self.timeout)
+                answer = self.invoke_with_timeout(chain, {"question": user_prompt}, self.timeout)
             except Timeout:
                 state.update({self.output[0]: {"error": "Response timeout exceeded"}})
                 return state
 
-            if output_parser:
-                chain = chain | output_parser
-
-            answer = chain.invoke({"question": user_prompt})
             state.update({self.output[0]: answer})
             return state
 
@@ -168,9 +166,9 @@ def invoke_with_timeout(chain, inputs, timeout):
 
         async_runner = RunnableParallel(**chains_dict)
         try:
-            batch_results = invoke_with_timeout(
-                async_runner,
-                {"question": user_prompt},
+            batch_results = self.invoke_with_timeout(
+                async_runner,
+                {"question": user_prompt},
                 self.timeout
             )
         except Timeout:
@@ -187,7 +185,7 @@ def invoke_with_timeout(chain, inputs, timeout):
         if output_parser:
             merge_chain = merge_chain | output_parser
         try:
-            answer = invoke_with_timeout(
+            answer = self.invoke_with_timeout(
                 merge_chain,
                 {"context": batch_results, "question": user_prompt},
                 self.timeout
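
Note on the promoted helper: invoke_with_timeout checks the elapsed time only after chain.invoke(inputs) has returned, so an over-long call is logged and raised as a Timeout after the fact rather than interrupted. The sketch below shows one way a hard wall-clock limit could be enforced instead, by running the call in a worker thread; invoke_with_hard_timeout and the module-level logger are illustrative stand-ins, not part of this commit.

import logging
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeout

logger = logging.getLogger(__name__)

def invoke_with_hard_timeout(chain, inputs, timeout):
    """Run chain.invoke(inputs) in a worker thread and give up after `timeout` seconds.

    Raises concurrent.futures.TimeoutError as soon as the deadline passes;
    the worker thread is not killed, only abandoned.
    """
    executor = ThreadPoolExecutor(max_workers=1)
    future = executor.submit(chain.invoke, inputs)
    try:
        return future.result(timeout=timeout)
    except FutureTimeout:
        logger.error("Response took longer than %s seconds", timeout)
        raise
    finally:
        # wait=False: return immediately instead of blocking on the abandoned call
        executor.shutdown(wait=False)

Usage mirrors the helper above, e.g. answer = invoke_with_hard_timeout(chain, {"question": user_prompt}, self.timeout).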