1+ """
2+ GenerateAnswerNode Module
3+ """
14from typing import List , Optional
25from langchain .prompts import PromptTemplate
36from langchain_core .output_parsers import JsonOutputParser
1518)
1619
1720class GenerateAnswerNode (BaseNode ):
21+ """
22+ Initializes the GenerateAnswerNode class.
23+
24+ Args:
25+ input (str): The input data type for the node.
26+ output (List[str]): The output data type(s) for the node.
27+ node_config (Optional[dict]): Configuration dictionary for the node,
28+ which includes the LLM model, verbosity, schema, and other settings.
29+ Defaults to None.
30+ node_name (str): The name of the node. Defaults to "GenerateAnswer".
31+
32+ Attributes:
33+ llm_model: The language model specified in the node configuration.
34+ verbose (bool): Whether verbose mode is enabled.
35+ force (bool): Whether to force certain behaviors, overriding defaults.
36+ script_creator (bool): Whether the node is in script creation mode.
37+ is_md_scraper (bool): Whether the node is scraping markdown data.
38+ additional_info (Optional[str]): Any additional information to be
39+ included in the prompt templates.
40+ """
1841 def __init__ (
1942 self ,
2043 input : str ,
@@ -34,7 +57,17 @@ def __init__(
         self.is_md_scraper = node_config.get("is_md_scraper", False)
         self.additional_info = node_config.get("additional_info")
 
-    def execute(self, state: dict) -> dict:
+    async def execute(self, state: dict) -> dict:
+        """
+        Executes the GenerateAnswerNode.
+
+        Args:
+            state (dict): The current state of the graph. The input keys will be used
+                to fetch the correct data from the state.
+
+        Returns:
+            dict: The updated state with the output key containing the generated answer.
+        """
         self.logger.info(f"--- Executing {self.node_name} Node ---")
 
         input_keys = self.get_input_keys(state)
@@ -90,7 +123,7 @@ def execute(self, state: dict) -> dict:
         chain = prompt | self.llm_model
         if output_parser:
             chain = chain | output_parser
-        answer = chain.ainvoke({"question": user_prompt})
+        answer = await chain.ainvoke({"question": user_prompt})
 
         state.update({self.output[0]: answer})
         return state
@@ -110,7 +143,7 @@ def execute(self, state: dict) -> dict:
                 chains_dict[chain_name] = chains_dict[chain_name] | output_parser
 
         async_runner = RunnableParallel(**chains_dict)
-        batch_results = async_runner.invoke({"question": user_prompt})
+        batch_results = await async_runner.ainvoke({"question": user_prompt})
 
         merge_prompt = PromptTemplate(
             template=template_merge_prompt,
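In the multi-chunk path, the synchronous `invoke` is replaced with an awaited `ainvoke`, so the per-chunk chains inside `RunnableParallel` run concurrently on the event loop instead of blocking it. A sketch of that fan-out with stand-in runnables; the chain names and outputs here are illustrative, not repo code:

```python
import asyncio

from langchain_core.runnables import RunnableLambda, RunnableParallel

# Stand-ins for the per-chunk `prompt | llm` chains built in the loop above.
chunk_0 = RunnableLambda(lambda x: f"summary of chunk 0 for: {x['question']}")
chunk_1 = RunnableLambda(lambda x: f"summary of chunk 1 for: {x['question']}")

async def main():
    async_runner = RunnableParallel(chunk_0=chunk_0, chunk_1=chunk_1)
    # ainvoke awaits all branches concurrently and returns a dict of results,
    # mirroring the `await async_runner.ainvoke(...)` change above.
    batch_results = await async_runner.ainvoke({"question": "What is X?"})
    print(batch_results)  # {'chunk_0': '...', 'chunk_1': '...'}

asyncio.run(main())
```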
@@ -121,7 +154,7 @@ def execute(self, state: dict) -> dict:
         merge_chain = merge_prompt | self.llm_model
         if output_parser:
             merge_chain = merge_chain | output_parser
-        answer = merge_chain.ainvoke({"context": batch_results, "question": user_prompt})
+        answer = await merge_chain.ainvoke({"context": batch_results, "question": user_prompt})
 
         state.update({self.output[0]: answer})
         return state
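Because `execute` is now a coroutine, callers must await it (or schedule it on an event loop) rather than call it directly. A hypothetical caller-side sketch follows; the node construction and state keys are assumptions, not repo code:

```python
import asyncio

# `node` is assumed to be a configured GenerateAnswerNode instance and
# `state` a graph state dict containing the node's input keys.
async def run_node(node, state: dict) -> dict:
    return await node.execute(state)  # execute() now returns a coroutine

# Illustrative usage:
# final_state = asyncio.run(run_node(node, {"user_prompt": "...", "doc": [...]}))
```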