1- """
2- GenerateAnswerNode Module
3- """
41from typing import List , Optional
52from langchain .prompts import PromptTemplate
63from langchain_core .output_parsers import JsonOutputParser
129from tqdm import tqdm
1310from .base_node import BaseNode
1411from ..utils .output_parser import get_structured_output_parser , get_pydantic_output_parser
15- from ..prompts import (TEMPLATE_CHUNKS ,
16- TEMPLATE_NO_CHUNKS , TEMPLATE_MERGE ,
17- TEMPLATE_CHUNKS_MD , TEMPLATE_NO_CHUNKS_MD ,
18- TEMPLATE_MERGE_MD )
12+ from ..prompts import (
13+ TEMPLATE_CHUNKS , TEMPLATE_NO_CHUNKS , TEMPLATE_MERGE ,
14+ TEMPLATE_CHUNKS_MD , TEMPLATE_NO_CHUNKS_MD , TEMPLATE_MERGE_MD
15+ )
 
 class GenerateAnswerNode(BaseNode):
-    """
-    A node that generates an answer using a large language model (LLM) based on the user's input
-    and the content extracted from a webpage. It constructs a prompt from the user's input
-    and the scraped content, feeds it to the LLM, and parses the LLM's response to produce
-    an answer.
-
-    Attributes:
-        llm_model: An instance of a language model client, configured for generating answers.
-        verbose (bool): A flag indicating whether to show print statements during execution.
-
-    Args:
-        input (str): Boolean expression defining the input keys needed from the state.
-        output (List[str]): List of output keys to be updated in the state.
-        node_config (dict): Additional configuration for the node.
-        node_name (str): The unique identifier name for the node, defaulting to "GenerateAnswer".
-    """
-
     def __init__(
         self,
         input: str,
@@ -43,119 +23,102 @@ def __init__(
         node_name: str = "GenerateAnswer",
     ):
         super().__init__(node_name, "node", input, output, 2, node_config)
-
         self.llm_model = node_config["llm_model"]
 
         if isinstance(node_config["llm_model"], ChatOllama):
-            self.llm_model.format = "json"
-
-        self.verbose = (
-            True if node_config is None else node_config.get("verbose", False)
-        )
-        self.force = (
-            False if node_config is None else node_config.get("force", False)
-        )
-        self.script_creator = (
-            False if node_config is None else node_config.get("script_creator", False)
-        )
-        self.is_md_scraper = (
-            False if node_config is None else node_config.get("is_md_scraper", False)
-        )
+            self.llm_model.format = "json"
 
+        self.verbose = node_config.get("verbose", False)
+        self.force = node_config.get("force", False)
+        self.script_creator = node_config.get("script_creator", False)
+        self.is_md_scraper = node_config.get("is_md_scraper", False)
         self.additional_info = node_config.get("additional_info")
 
     def execute(self, state: dict) -> dict:
-        """
-        Generates an answer by constructing a prompt from the user's input and the scraped
-        content, querying the language model, and parsing its response.
-
-        Args:
-            state (dict): The current state of the graph. The input keys will be used
-                to fetch the correct data from the state.
-
-        Returns:
-            dict: The updated state with the output key containing the generated answer.
-
-        Raises:
-            KeyError: If the input keys are not found in the state, indicating
-                that the necessary information for generating an answer is missing.
-        """
-
         self.logger.info(f"--- Executing {self.node_name} Node ---")
 
-        input_keys = self.get_input_keys(state)
+        input_keys = self.get_input_keys(state)
         input_data = [state[key] for key in input_keys]
         user_prompt = input_data[0]
         doc = input_data[1]
 
         if self.node_config.get("schema", None) is not None:
-
             if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
                 self.llm_model = self.llm_model.with_structured_output(
-                    schema=self.node_config["schema"])
+                    schema=self.node_config["schema"]
+                )
                 output_parser = get_structured_output_parser(self.node_config["schema"])
                 format_instructions = "NA"
             else:
-                output_parser = get_pydantic_output_parser(self.node_config["schema"])
-                format_instructions = output_parser.get_format_instructions()
-
+                if not isinstance(self.llm_model, ChatBedrock):
+                    output_parser = get_pydantic_output_parser(self.node_config["schema"])
+                    format_instructions = output_parser.get_format_instructions()
+                else:
+                    output_parser = None
+                    format_instructions = ""
         else:
-            output_parser = JsonOutputParser()
-            format_instructions = output_parser.get_format_instructions()
+            if not isinstance(self.llm_model, ChatBedrock):
+                output_parser = JsonOutputParser()
+                format_instructions = output_parser.get_format_instructions()
+            else:
+                output_parser = None
+                format_instructions = ""
 
         if isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) \
             and not self.script_creator \
             or self.force \
             and not self.script_creator or self.is_md_scraper:
-
-            template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
-            template_chunks_prompt = TEMPLATE_CHUNKS_MD
-            template_merge_prompt = TEMPLATE_MERGE_MD
+            template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
+            template_chunks_prompt = TEMPLATE_CHUNKS_MD
+            template_merge_prompt = TEMPLATE_MERGE_MD
         else:
-            template_no_chunks_prompt = TEMPLATE_NO_CHUNKS
-            template_chunks_prompt = TEMPLATE_CHUNKS
-            template_merge_prompt = TEMPLATE_MERGE
+            template_no_chunks_prompt = TEMPLATE_NO_CHUNKS
+            template_chunks_prompt = TEMPLATE_CHUNKS
+            template_merge_prompt = TEMPLATE_MERGE
 
         if self.additional_info is not None:
-            template_no_chunks_prompt = self.additional_info + template_no_chunks_prompt
-            template_chunks_prompt = self.additional_info + template_chunks_prompt
-            template_merge_prompt = self.additional_info + template_merge_prompt
+            template_no_chunks_prompt = self.additional_info + template_no_chunks_prompt
+            template_chunks_prompt = self.additional_info + template_chunks_prompt
+            template_merge_prompt = self.additional_info + template_merge_prompt
 
         if len(doc) == 1:
             prompt = PromptTemplate(
-                template=template_no_chunks_prompt,
+                template=template_no_chunks_prompt,
                 input_variables=["question"],
-                partial_variables={"context": doc,
-                                   "format_instructions": format_instructions})
-            chain = prompt | self.llm_model | output_parser
+                partial_variables={"context": doc, "format_instructions": format_instructions}
+            )
+            chain = prompt | self.llm_model
+            if output_parser:
+                chain = chain | output_parser
             answer = chain.invoke({"question": user_prompt})
 
             state.update({self.output[0]: answer})
             return state
 
         chains_dict = {}
         for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
-
             prompt = PromptTemplate(
-                template=TEMPLATE_CHUNKS,
+                template=template_chunks_prompt,
                 input_variables=["question"],
-                partial_variables={"context": chunk,
-                                   "chunk_id": i + 1,
-                                   "format_instructions": format_instructions})
+                partial_variables={"context": chunk, "chunk_id": i + 1, "format_instructions": format_instructions}
+            )
             chain_name = f"chunk{i+1}"
-            chains_dict[chain_name] = prompt | self.llm_model | output_parser
+            chains_dict[chain_name] = prompt | self.llm_model
+            if output_parser:
+                chains_dict[chain_name] = chains_dict[chain_name] | output_parser
 
         async_runner = RunnableParallel(**chains_dict)
-
-        batch_results = async_runner.invoke({"question": user_prompt})
+        batch_results = async_runner.invoke({"question": user_prompt})
 
         merge_prompt = PromptTemplate(
-            template=template_merge_prompt,
-            input_variables=["context", "question"],
-            partial_variables={"format_instructions": format_instructions},
-        )
+            template=template_merge_prompt,
+            input_variables=["context", "question"],
+            partial_variables={"format_instructions": format_instructions}
+        )
 
-        merge_chain = merge_prompt | self.llm_model | output_parser
+        merge_chain = merge_prompt | self.llm_model
+        if output_parser:
+            merge_chain = merge_chain | output_parser
         answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
 
         state.update({self.output[0]: answer})
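
The substantive change in the chain construction above is that a parser is piped in only when one exists: for ChatBedrock, output_parser is None and the chain ends at the model, returning its raw output. A minimal sketch of that pattern in isolation, with a RunnableLambda standing in for the model client (the stand-in and the prompt text are illustrative assumptions, not part of this module):

from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda

# Stand-in for the LLM client; the real node receives e.g. ChatOpenAI or ChatBedrock.
fake_llm = RunnableLambda(lambda _: '{"answer": "42"}')

# Setting this to None mirrors the ChatBedrock branch, where no parser is attached
# and callers receive the raw model output.
output_parser = JsonOutputParser()

prompt = PromptTemplate(
    template="Answer the question: {question}\n{format_instructions}",
    input_variables=["question"],
    partial_variables={"format_instructions": ""},
)

chain = prompt | fake_llm
if output_parser:
    chain = chain | output_parser

print(chain.invoke({"question": "What is the answer?"}))  # -> {'answer': '42'}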
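For context, a hypothetical end-to-end invocation of the node (the state keys follow the input/output conventions visible in this file; the model choice and scraped content are assumptions):

from langchain_openai import ChatOpenAI

node = GenerateAnswerNode(
    input="user_prompt & doc",
    output=["answer"],
    node_config={"llm_model": ChatOpenAI(model="gpt-4o-mini"), "verbose": True},
)

# A single-element doc takes the no-chunks path; multiple chunks are processed
# in parallel via RunnableParallel and combined by the merge prompt.
state = node.execute({
    "user_prompt": "What is this page about?",
    "doc": ["<scraped page content>"],
})
print(state["answer"])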