22gg
33Module for generating the answer node
44"""
5+
56# Imports from standard library
67from typing import List , Optional
7- from tqdm import tqdm
88
99# Imports from Langchain
1010from langchain .prompts import PromptTemplate
1111from langchain_core .output_parsers import JsonOutputParser
1212from langchain_core .runnables import RunnableParallel
13+ from tqdm import tqdm
14+
1315from ..utils .logging import get_logger
1416
1517# Imports from the library
@@ -25,24 +27,29 @@ class GenerateAnswerCSVNode(BaseNode):
2527
2628 Attributes:
2729 llm_model: An instance of a language model client, configured for generating answers.
28- node_name (str): The unique identifier name for the node, defaulting
30+ node_name (str): The unique identifier name for the node, defaulting
2931 to "GenerateAnswerNodeCsv".
30- node_type (str): The type of the node, set to "node" indicating a
32+ node_type (str): The type of the node, set to "node" indicating a
3133 standard operational node.
3234
3335 Args:
34- llm_model: An instance of the language model client (e.g., ChatOpenAI) used
36+ llm_model: An instance of the language model client (e.g., ChatOpenAI) used
3537 for generating answers.
36- node_name (str, optional): The unique identifier name for the node.
38+ node_name (str, optional): The unique identifier name for the node.
3739 Defaults to "GenerateAnswerNodeCsv".
3840
3941 Methods:
4042 execute(state): Processes the input and document from the state to generate an answer,
4143 updating the state with the generated answer under the 'answer' key.
4244 """
4345
44- def __init__ (self , input : str , output : List [str ], node_config : Optional [dict ] = None ,
45- node_name : str = "GenerateAnswer" ):
46+ def __init__ (
47+ self ,
48+ input : str ,
49+ output : List [str ],
50+ node_config : Optional [dict ] = None ,
51+ node_name : str = "GenerateAnswer" ,
52+ ):
4653 """
4754 Initializes the GenerateAnswerNodeCsv with a language model client and a node name.
4855 Args:
@@ -51,8 +58,9 @@ def __init__(self, input: str, output: List[str], node_config: Optional[dict] =
5158 """
5259 super ().__init__ (node_name , "node" , input , output , 2 , node_config )
5360 self .llm_model = node_config ["llm_model" ]
54- self .verbose = False if node_config is None else node_config .get (
55- "verbose" , False )
61+ self .verbose = (
62+ False if node_config is None else node_config .get ("verbose" , False )
63+ )
5664
5765 def execute (self , state ):
5866 """
@@ -73,8 +81,7 @@ def execute(self, state):
7381 that the necessary information for generating an answer is missing.
7482 """
7583
76- if self .verbose :
77- self .logger .info (f"--- Executing { self .node_name } Node ---" )
84+ self .logger .info (f"--- Executing { self .node_name } Node ---" )
7885
7986 # Interpret input keys based on the provided input expression
8087 input_keys = self .get_input_keys (state )
@@ -122,21 +129,27 @@ def execute(self, state):
122129 chains_dict = {}
123130
124131 # Use tqdm to add progress bar
125- for i , chunk in enumerate (tqdm (doc , desc = "Processing chunks" , disable = not self .verbose )):
132+ for i , chunk in enumerate (
133+ tqdm (doc , desc = "Processing chunks" , disable = not self .verbose )
134+ ):
126135 if len (doc ) == 1 :
127136 prompt = PromptTemplate (
128137 template = template_no_chunks ,
129138 input_variables = ["question" ],
130- partial_variables = {"context" : chunk .page_content ,
131- "format_instructions" : format_instructions },
139+ partial_variables = {
140+ "context" : chunk .page_content ,
141+ "format_instructions" : format_instructions ,
142+ },
132143 )
133144 else :
134145 prompt = PromptTemplate (
135146 template = template_chunks ,
136147 input_variables = ["question" ],
137- partial_variables = {"context" : chunk .page_content ,
138- "chunk_id" : i + 1 ,
139- "format_instructions" : format_instructions },
148+ partial_variables = {
149+ "context" : chunk .page_content ,
150+ "chunk_id" : i + 1 ,
151+ "format_instructions" : format_instructions ,
152+ },
140153 )
141154
142155 # Dynamically name the chains based on their index
@@ -155,8 +168,7 @@ def execute(self, state):
155168 partial_variables = {"format_instructions" : format_instructions },
156169 )
157170 merge_chain = merge_prompt | self .llm_model | output_parser
158- answer = merge_chain .invoke (
159- {"context" : answer , "question" : user_prompt })
171+ answer = merge_chain .invoke ({"context" : answer , "question" : user_prompt })
160172 else :
161173 # Chain
162174 single_chain = list (chains_dict .values ())[0 ]
0 commit comments