@@ -52,9 +52,6 @@ def __init__(
5252 super ().__init__ (node_name , "node" , input , output , 2 , node_config )
5353 self .llm_model = node_config ["llm_model" ]
5454
55- if hasattr (self .llm_model , 'request_timeout' ):
56- self .llm_model .request_timeout = node_config .get ("timeout" , 30 )
57-
5855 if isinstance (node_config ["llm_model" ], ChatOllama ):
5956 self .llm_model .format = "json"
6057
@@ -63,7 +60,22 @@ def __init__(
6360 self .script_creator = node_config .get ("script_creator" , False )
6461 self .is_md_scraper = node_config .get ("is_md_scraper" , False )
6562 self .additional_info = node_config .get ("additional_info" )
66- self .timeout = node_config .get ("timeout" , 30 )
63+ self .timeout = node_config .get ("timeout" , 120 )
64+
65+ def invoke_with_timeout (self , chain , inputs , timeout ):
66+ """Helper method to invoke chain with timeout"""
67+ try :
68+ start_time = time .time ()
69+ response = chain .invoke (inputs )
70+ if time .time () - start_time > timeout :
71+ raise Timeout (f"Response took longer than { timeout } seconds" )
72+ return response
73+ except Timeout as e :
74+ self .logger .error (f"Timeout error: { str (e )} " )
75+ raise
76+ except Exception as e :
77+ self .logger .error (f"Error during chain execution: { str (e )} " )
78+ raise
6779
6880 def execute (self , state : dict ) -> dict :
6981 """
@@ -119,39 +131,22 @@ def execute(self, state: dict) -> dict:
119131 template_chunks_prompt = self .additional_info + template_chunks_prompt
120132 template_merge_prompt = self .additional_info + template_merge_prompt
121133
122- def invoke_with_timeout (chain , inputs , timeout ):
123- try :
124- with get_openai_callback () as cb :
125- start_time = time .time ()
126- response = chain .invoke (inputs )
127- if time .time () - start_time > timeout :
128- raise Timeout (f"Response took longer than { timeout } seconds" )
129- return response
130- except Timeout as e :
131- self .logger .error (f"Timeout error: { str (e )} " )
132- raise
133- except Exception as e :
134- self .logger .error (f"Error during chain execution: { str (e )} " )
135- raise
136-
137134 if len (doc ) == 1 :
138135 prompt = PromptTemplate (
139136 template = template_no_chunks_prompt ,
140137 input_variables = ["question" ],
141138 partial_variables = {"context" : doc , "format_instructions" : format_instructions }
142139 )
143140 chain = prompt | self .llm_model
141+ if output_parser :
142+ chain = chain | output_parser
144143
145144 try :
146- raw_response = invoke_with_timeout (chain , {"question" : user_prompt }, self .timeout )
145+ answer = self . invoke_with_timeout (chain , {"question" : user_prompt }, self .timeout )
147146 except Timeout :
148147 state .update ({self .output [0 ]: {"error" : "Response timeout exceeded" }})
149148 return state
150149
151- if output_parser :
152- chain = chain | output_parser
153-
154- answer = chain .invoke ({"question" : user_prompt })
155150 state .update ({self .output [0 ]: answer })
156151 return state
157152
@@ -171,9 +166,9 @@ def invoke_with_timeout(chain, inputs, timeout):
171166
172167 async_runner = RunnableParallel (** chains_dict )
173168 try :
174- batch_results = invoke_with_timeout (
175- async_runner ,
176- {"question" : user_prompt },
169+ batch_results = self . invoke_with_timeout (
170+ async_runner ,
171+ {"question" : user_prompt },
177172 self .timeout
178173 )
179174 except Timeout :
@@ -190,7 +185,7 @@ def invoke_with_timeout(chain, inputs, timeout):
190185 if output_parser :
191186 merge_chain = merge_chain | output_parser
192187 try :
193- answer = invoke_with_timeout (
188+ answer = self . invoke_with_timeout (
194189 merge_chain ,
195190 {"context" : batch_results , "question" : user_prompt },
196191 self .timeout
0 commit comments