diff --git a/scrapegraphai/nodes/generate_answer_node.py b/scrapegraphai/nodes/generate_answer_node.py
index 0aaee25c..d65c1add 100644
--- a/scrapegraphai/nodes/generate_answer_node.py
+++ b/scrapegraphai/nodes/generate_answer_node.py
@@ -16,6 +16,10 @@
     TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
     TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD
 )
+from langchain.callbacks.manager import CallbackManager
+from langchain.callbacks import get_openai_callback
+from requests.exceptions import Timeout
+import time
 
 class GenerateAnswerNode(BaseNode):
     """
@@ -56,6 +60,7 @@ def __init__(
         self.script_creator = node_config.get("script_creator", False)
         self.is_md_scraper = node_config.get("is_md_scraper", False)
         self.additional_info = node_config.get("additional_info")
+        self.timeout = node_config.get("timeout", 30)
 
     def execute(self, state: dict) -> dict:
         """
@@ -114,6 +119,21 @@ def execute(self, state: dict) -> dict:
             template_chunks_prompt = self.additional_info + template_chunks_prompt
             template_merge_prompt = self.additional_info + template_merge_prompt
 
+        def invoke_with_timeout(chain, inputs, timeout):
+            try:
+                with get_openai_callback() as cb:
+                    start_time = time.time()
+                    response = chain.invoke(inputs)
+                    if time.time() - start_time > timeout:
+                        raise Timeout(f"Response took longer than {timeout} seconds")
+                    return response
+            except Timeout as e:
+                self.logger.error(f"Timeout error: {str(e)}")
+                raise
+            except Exception as e:
+                self.logger.error(f"Error during chain execution: {str(e)}")
+                raise
+
         if len(doc) == 1:
             prompt = PromptTemplate(
                 template=template_no_chunks_prompt,
@@ -121,7 +141,11 @@ def execute(self, state: dict) -> dict:
                 partial_variables={"context": doc, "format_instructions": format_instructions}
             )
             chain = prompt | self.llm_model
-            raw_response = chain.invoke({"question": user_prompt})
+            try:
+                raw_response = invoke_with_timeout(chain, {"question": user_prompt}, self.timeout)
+            except Timeout:
+                state.update({self.output[0]: {"error": "Response timeout exceeded"}})
+                return state
 
             if output_parser:
                 try:
@@ -155,7 +179,15 @@ def execute(self, state: dict) -> dict:
                 chains_dict[chain_name] = chains_dict[chain_name] | output_parser
 
         async_runner = RunnableParallel(**chains_dict)
-        batch_results = async_runner.invoke({"question": user_prompt})
+        try:
+            batch_results = invoke_with_timeout(
+                async_runner,
+                {"question": user_prompt},
+                self.timeout
+            )
+        except Timeout:
+            state.update({self.output[0]: {"error": "Response timeout exceeded during chunk processing"}})
+            return state
 
         merge_prompt = PromptTemplate(
             template=template_merge_prompt,
@@ -166,7 +198,15 @@ def execute(self, state: dict) -> dict:
         merge_chain = merge_prompt | self.llm_model
         if output_parser:
             merge_chain = merge_chain | output_parser
-        answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
+        try:
+            answer = invoke_with_timeout(
+                merge_chain,
+                {"context": batch_results, "question": user_prompt},
+                self.timeout
+            )
+        except Timeout:
+            state.update({self.output[0]: {"error": "Response timeout exceeded during merge"}})
+            return state
 
         state.update({self.output[0]: answer})
         return state
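
A note on semantics: invoke_with_timeout measures elapsed time only after chain.invoke returns, so a call that hangs is never interrupted; the Timeout fires retroactively once the response is already in hand. Below is a self-contained sketch contrasting that post-hoc check with a pre-emptive variant built on concurrent.futures. All names here (slow_invoke, invoke_post_hoc, invoke_preemptive) are illustrative and not part of the patch.

import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeout

from requests.exceptions import Timeout


def slow_invoke(inputs: dict) -> dict:
    """Stand-in for chain.invoke: takes ~2 seconds to answer."""
    time.sleep(2)
    return {"answer": inputs["question"]}


def invoke_post_hoc(func, inputs, timeout):
    # Mirrors the patch: the call always runs to completion, and the
    # Timeout is raised afterwards if the elapsed time exceeded the limit.
    start = time.time()
    response = func(inputs)
    if time.time() - start > timeout:
        raise Timeout(f"Response took longer than {timeout} seconds")
    return response


def invoke_preemptive(func, inputs, timeout):
    # Hypothetical alternative: stop waiting after `timeout` seconds.
    # The worker thread is not killed; it finishes in the background
    # and is only joined when the interpreter exits.
    pool = ThreadPoolExecutor(max_workers=1)
    future = pool.submit(func, inputs)
    try:
        return future.result(timeout=timeout)
    except FutureTimeout:
        raise Timeout(f"No response within {timeout} seconds") from None
    finally:
        pool.shutdown(wait=False)


if __name__ == "__main__":
    for variant in (invoke_post_hoc, invoke_preemptive):
        try:
            variant(slow_invoke, {"question": "hi"}, timeout=1)
        except Timeout as exc:
            print(f"{variant.__name__}: {exc}")

The trade-off: the pre-emptive version hands control back promptly but leaves the in-flight request running, while the post-hoc approach taken by the patch stays simple and still converts slow responses into a uniform error path.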
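
On the caller side, a timeout no longer propagates out of execute as an exception; it lands in the returned state under the node's output key as an {"error": ...} dict. A minimal, hypothetical consumer check follows; the "answer" key stands in for self.output[0] and is an assumption, as is the unwrap_answer helper.

def unwrap_answer(state: dict, key: str = "answer"):
    # Hypothetical helper; "answer" stands in for self.output[0].
    result = state[key]
    # On timeout the patch stores one of three payloads:
    #   {"error": "Response timeout exceeded"}
    #   {"error": "Response timeout exceeded during chunk processing"}
    #   {"error": "Response timeout exceeded during merge"}
    if isinstance(result, dict) and set(result) == {"error"}:
        raise RuntimeError(f"answer generation failed: {result['error']}")
    return result


print(unwrap_answer({"answer": {"title": "ok"}}))  # normal answer passes through

try:
    unwrap_answer({"answer": {"error": "Response timeout exceeded"}})
except RuntimeError as exc:
    print(exc)  # timeout payload surfaces as an exception again

One consequence of this contract: a legitimate parsed answer whose only key happens to be "error" would be indistinguishable from a timeout, so callers wanting a stricter signal may prefer a sentinel object or a dedicated state key.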