1010from tqdm import tqdm
1111from ..utils .logging import get_logger
1212from .base_node import BaseNode
13- from ..prompts import template_chunks , template_no_chunks , template_merge , template_chunks_md , template_no_chunks_md , template_merge_md
13+ from ..prompts import TEMPLATE_CHUNKS , TEMPLATE_NO_CHUNKS , TEMPLATE_MERGE , TEMPLATE_CHUNKS_MD , TEMPLATE_NO_CHUNKS_MD , TEMPLATE_MERGE_MD
1414
1515class GenerateAnswerNode (BaseNode ):
1616 """
@@ -98,23 +98,23 @@ def execute(self, state: dict) -> dict:
9898
9999 format_instructions = output_parser .get_format_instructions ()
100100
101- template_no_chunks_prompt = template_no_chunks
102- template_chunks_prompt = template_chunks
103- template_merge_prompt = template_merge
104-
105101 if isinstance (self .llm_model , ChatOpenAI ) and not self .script_creator or self .force and not self .script_creator or self .is_md_scraper :
106- template_no_chunks_prompt = template_no_chunks_md
107- template_chunks_prompt = template_chunks_md
108- template_merge_prompt = template_merge_md
102+ template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
103+ template_chunks_prompt = TEMPLATE_CHUNKS_MD
104+ template_merge_prompt = TEMPLATE_MERGE_MD
105+ else :
106+ template_no_chunks_prompt = TEMPLATE_NO_CHUNKS
107+ template_chunks_prompt = TEMPLATE_CHUNKS
108+ template_merge_prompt = TEMPLATE_MERGE
109109
110110 if self .additional_info is not None :
111- template_no_chunks_prompt = self .additional_info + template_no_chunks_prompt
112- template_chunks_prompt = self .additional_info + template_chunks_prompt
113- template_merge_prompt = self .additional_info + template_merge_prompt
111+ template_no_chunks_prompt = self .additional_info + template_no_chunks_prompt
112+ template_chunks_prompt = self .additional_info + template_chunks_prompt
113+ template_merge_prompt = self .additional_info + template_merge_prompt
114114
115115 if len (doc ) == 1 :
116116 prompt = PromptTemplate (
117- template = template_no_chunks_prompt ,
117+ template = template_no_chunks_prompt ,
118118 input_variables = ["question" ],
119119 partial_variables = {"context" : doc ,
120120 "format_instructions" : format_instructions })
@@ -128,7 +128,7 @@ def execute(self, state: dict) -> dict:
128128 for i , chunk in enumerate (tqdm (doc , desc = "Processing chunks" , disable = not self .verbose )):
129129
130130 prompt = PromptTemplate (
131- template = template_chunks ,
131+ template = TEMPLATE_CHUNKS ,
132132 input_variables = ["question" ],
133133 partial_variables = {"context" : chunk ,
134134 "chunk_id" : i + 1 ,
@@ -141,7 +141,7 @@ def execute(self, state: dict) -> dict:
141141 batch_results = async_runner .invoke ({"question" : user_prompt })
142142
143143 merge_prompt = PromptTemplate (
144- template = template_merge_prompt ,
144+ template = template_merge_prompt ,
145145 input_variables = ["context" , "question" ],
146146 partial_variables = {"format_instructions" : format_instructions },
147147 )
0 commit comments