22GenerateCodeNode Module
33"""
44from typing import Any , Dict , List , Optional
5- from langchain .prompts import PromptTemplate
6- from langchain .output_parsers import ResponseSchema , StructuredOutputParser
7- from langchain_core .output_parsers import StrOutputParser
8- from langchain_core .runnables import RunnableParallel
9- from langchain_core .utils .pydantic import is_basemodel_subclass
10- from langchain_community .chat_models import ChatOllama
115import ast
126import sys
137from io import StringIO
14- from bs4 import BeautifulSoup
158import re
16- from tqdm import tqdm
17- from .base_node import BaseNode
9+ import json
1810from pydantic import ValidationError
11+ from langchain .prompts import PromptTemplate
12+ from langchain .output_parsers import ResponseSchema , StructuredOutputParser
13+ from langchain_core .output_parsers import StrOutputParser
14+ from langchain_community .chat_models import ChatOllama
15+ from bs4 import BeautifulSoup
16+ from ..prompts import (
17+ TEMPLATE_INIT_CODE_GENERATION , TEMPLATE_SEMANTIC_COMPARISON
18+ )
1919from ..utils import (transform_schema ,
2020 extract_code ,
2121 syntax_focused_analysis , syntax_focused_code_generation ,
2222 execution_focused_analysis , execution_focused_code_generation ,
2323 validation_focused_analysis , validation_focused_code_generation ,
2424 semantic_focused_analysis , semantic_focused_code_generation ,
2525 are_content_equal )
26+ from .base_node import BaseNode
2627from jsonschema import validate , ValidationError
27- import json
28- from ..prompts import (
29- TEMPLATE_INIT_CODE_GENERATION , TEMPLATE_SEMANTIC_COMPARISON
30- )
28+
3129
3230class GenerateCodeNode (BaseNode ):
3331 """
34- A node that generates Python code for a function that extracts data from HTML based on a output schema.
32+ A node that generates Python code for a function that extracts data
33+ from HTML based on a output schema.
3534
3635 Attributes:
3736 llm_model: An instance of a language model client, configured for generating answers.
@@ -72,15 +71,15 @@ def __init__(
7271 )
7372
7473 self .additional_info = node_config .get ("additional_info" )
75-
74+
7675 self .max_iterations = node_config .get ("max_iterations" , {
7776 "overall" : 10 ,
7877 "syntax" : 3 ,
7978 "execution" : 3 ,
8079 "validation" : 3 ,
8180 "semantic" : 3
8281 })
83-
82+
8483 self .output_schema = node_config .get ("schema" )
8584
8685 def execute (self , state : dict ) -> dict :
@@ -97,25 +96,26 @@ def execute(self, state: dict) -> dict:
9796 Raises:
9897 KeyError: If the input keys are not found in the state, indicating
9998 that the necessary information for generating an answer is missing.
100- RuntimeError: If the maximum number of iterations is reached without obtaining the desired code.
99+ RuntimeError: If the maximum number of iterations is
100+ reached without obtaining the desired code.
101101 """
102-
102+
103103 self .logger .info (f"--- Executing { self .node_name } Node ---" )
104104
105105 input_keys = self .get_input_keys (state )
106-
106+
107107 input_data = [state [key ] for key in input_keys ]
108-
108+
109109 user_prompt = input_data [0 ]
110110 refined_prompt = input_data [1 ]
111111 html_info = input_data [2 ]
112112 reduced_html = input_data [3 ]
113- answer = input_data [4 ]
114-
113+ answer = input_data [4 ]
114+
115115 self .raw_html = state ['original_html' ][0 ].page_content
116-
116+
117117 simplefied_schema = str (transform_schema (self .output_schema .schema ()))
118-
118+
119119 reasoning_state = {
120120 "user_input" : user_prompt ,
121121 "json_schema" : simplefied_schema ,
@@ -133,113 +133,128 @@ def execute(self, state: dict) -> dict:
133133 },
134134 "iteration" : 0
135135 }
136-
137-
136+
138137 final_state = self .overall_reasoning_loop (reasoning_state )
139-
138+
140139 state .update ({self .output [0 ]: final_state ["generated_code" ]})
141140 return state
142-
141+
143142 def overall_reasoning_loop (self , state : dict ) -> dict :
143+ """
144+ overrall_reasoning_loop
145+ """
144146 self .logger .info (f"--- (Generating Code) ---" )
145147 state ["generated_code" ] = self .generate_initial_code (state )
146148 state ["generated_code" ] = extract_code (state ["generated_code" ])
147-
149+
148150 while state ["iteration" ] < self .max_iterations ["overall" ]:
149151 state ["iteration" ] += 1
150152 if self .verbose :
151153 self .logger .info (f"--- Iteration { state ['iteration' ]} ---" )
152-
154+
153155 self .logger .info (f"--- (Checking Code Syntax) ---" )
154156 state = self .syntax_reasoning_loop (state )
155157 if state ["errors" ]["syntax" ]:
156158 continue
157-
159+
158160 self .logger .info (f"--- (Executing the Generated Code) ---" )
159161 state = self .execution_reasoning_loop (state )
160162 if state ["errors" ]["execution" ]:
161163 continue
162-
164+
163165 self .logger .info (f"--- (Validate the Code Output Schema) ---" )
164166 state = self .validation_reasoning_loop (state )
165167 if state ["errors" ]["validation" ]:
166168 continue
167-
169+
168170 self .logger .info (f"--- (Checking if the informations exctrcated are the ones Requested) ---" )
169171 state = self .semantic_comparison_loop (state )
170172 if state ["errors" ]["semantic" ]:
171- continue
173+ continue
172174 break
173-
175+
174176 if state ["iteration" ] == self .max_iterations ["overall" ] and (state ["errors" ]["syntax" ] or state ["errors" ]["execution" ] or state ["errors" ]["validation" ] or state ["errors" ]["semantic" ]):
175177 raise RuntimeError ("Max iterations reached without obtaining the desired code." )
176-
178+
177179 self .logger .info (f"--- (Code Generated Correctly) ---" )
178-
180+
179181 return state
180-
182+
181183 def syntax_reasoning_loop (self , state : dict ) -> dict :
184+ """
185+ syntax reasoning loop
186+ """
182187 for _ in range (self .max_iterations ["syntax" ]):
183188 syntax_valid , syntax_message = self .syntax_check (state ["generated_code" ])
184189 if syntax_valid :
185190 state ["errors" ]["syntax" ] = []
186191 return state
187-
192+
188193 state ["errors" ]["syntax" ] = [syntax_message ]
189194 self .logger .info (f"--- (Synax Error Found: { syntax_message } ) ---" )
190195 analysis = syntax_focused_analysis (state , self .llm_model )
191- self .logger .info (f"--- (Regenerating Code to fix the Error) ---" )
192- state ["generated_code" ] = syntax_focused_code_generation (state , analysis , self .llm_model )
196+ self .logger .info (f"""--- (Regenerating Code
197+ to fix the Error) ---""" )
198+ state ["generated_code" ] = syntax_focused_code_generation (state ,
199+ analysis , self .llm_model )
193200 state ["generated_code" ] = extract_code (state ["generated_code" ])
194201 return state
195-
202+
196203 def execution_reasoning_loop (self , state : dict ) -> dict :
204+ """
205+ execution of the reasoning loop
206+ """
197207 for _ in range (self .max_iterations ["execution" ]):
198208 execution_success , execution_result = self .create_sandbox_and_execute (state ["generated_code" ])
199209 if execution_success :
200210 state ["execution_result" ] = execution_result
201211 state ["errors" ]["execution" ] = []
202212 return state
203-
213+
204214 state ["errors" ]["execution" ] = [execution_result ]
205215 self .logger .info (f"--- (Code Execution Error: { execution_result } ) ---" )
206216 analysis = execution_focused_analysis (state , self .llm_model )
207217 self .logger .info (f"--- (Regenerating Code to fix the Error) ---" )
208- state ["generated_code" ] = execution_focused_code_generation (state , analysis , self .llm_model )
218+ state ["generated_code" ] = execution_focused_code_generation (state ,
219+ analysis , self .llm_model )
209220 state ["generated_code" ] = extract_code (state ["generated_code" ])
210221 return state
211-
222+
212223 def validation_reasoning_loop (self , state : dict ) -> dict :
213224 for _ in range (self .max_iterations ["validation" ]):
214- validation , errors = self .validate_dict (state ["execution_result" ], self .output_schema .schema ())
225+ validation , errors = self .validate_dict (state ["execution_result" ],
226+ self .output_schema .schema ())
215227 if validation :
216228 state ["errors" ]["validation" ] = []
217229 return state
218-
230+
219231 state ["errors" ]["validation" ] = errors
220232 self .logger .info (f"--- (Code Output not compliant to the deisred Output Schema) ---" )
221233 analysis = validation_focused_analysis (state , self .llm_model )
222234 self .logger .info (f"--- (Regenerating Code to make the Output compliant to the deisred Output Schema) ---" )
223235 state ["generated_code" ] = validation_focused_code_generation (state , analysis , self .llm_model )
224236 state ["generated_code" ] = extract_code (state ["generated_code" ])
225237 return state
226-
238+
227239 def semantic_comparison_loop (self , state : dict ) -> dict :
228240 for _ in range (self .max_iterations ["semantic" ]):
229241 comparison_result = self .semantic_comparison (state ["execution_result" ], state ["reference_answer" ])
230242 if comparison_result ["are_semantically_equivalent" ]:
231243 state ["errors" ]["semantic" ] = []
232244 return state
233-
245+
234246 state ["errors" ]["semantic" ] = comparison_result ["differences" ]
235247 self .logger .info (f"--- (The informations exctrcated are not the all ones requested) ---" )
236248 analysis = semantic_focused_analysis (state , comparison_result , self .llm_model )
237249 self .logger .info (f"--- (Regenerating Code to obtain all the infromation requested) ---" )
238250 state ["generated_code" ] = semantic_focused_code_generation (state , analysis , self .llm_model )
239251 state ["generated_code" ] = extract_code (state ["generated_code" ])
240252 return state
241-
253+
242254 def generate_initial_code (self , state : dict ) -> str :
255+ """
256+ function for generating the initial code
257+ """
243258 prompt = PromptTemplate (
244259 template = TEMPLATE_INIT_CODE_GENERATION ,
245260 partial_variables = {
@@ -255,22 +270,29 @@ def generate_initial_code(self, state: dict) -> str:
255270 chain = prompt | self .llm_model | output_parser
256271 generated_code = chain .invoke ({})
257272 return generated_code
258-
273+
259274 def semantic_comparison (self , generated_result : Any , reference_result : Any ) -> Dict [str , Any ]:
275+ """
276+ semtantic comparison formula
277+ """
260278 reference_result_dict = self .output_schema (** reference_result ).dict ()
261-
262- # Check if generated result and reference result are actually equal
263279 if are_content_equal (generated_result , reference_result_dict ):
264280 return {
265281 "are_semantically_equivalent" : True ,
266282 "differences" : [],
267283 "explanation" : "The generated result and reference result are exactly equal."
268284 }
269-
285+
270286 response_schemas = [
271- ResponseSchema (name = "are_semantically_equivalent" , description = "Boolean indicating if the results are semantically equivalent" ),
272- ResponseSchema (name = "differences" , description = "List of semantic differences between the results, if any" ),
273- ResponseSchema (name = "explanation" , description = "Detailed explanation of the comparison and reasoning" )
287+ ResponseSchema (name = "are_semantically_equivalent" ,
288+ description = """Boolean indicating if the
289+ results are semantically equivalent""" ),
290+ ResponseSchema (name = "differences" ,
291+ description = """List of semantic differences
292+ between the results, if any""" ),
293+ ResponseSchema (name = "explanation" ,
294+ description = """Detailed explanation of the
295+ comparison and reasoning""" )
274296 ]
275297 output_parser = StructuredOutputParser .from_response_schemas (response_schemas )
276298
@@ -285,45 +307,52 @@ def semantic_comparison(self, generated_result: Any, reference_result: Any) -> D
285307 "generated_result" : json .dumps (generated_result , indent = 2 ),
286308 "reference_result" : json .dumps (reference_result_dict , indent = 2 )
287309 })
288-
310+
289311 def syntax_check (self , code ):
312+ """
313+ syntax checker
314+ """
290315 try :
291316 ast .parse (code )
292317 return True , "Syntax is correct."
293318 except SyntaxError as e :
294319 return False , f"Syntax error: { str (e )} "
295320
296321 def create_sandbox_and_execute (self , function_code ):
297- # Create a sandbox environment
322+ """
323+ Create a sandbox environment
324+ """
298325 sandbox_globals = {
299326 'BeautifulSoup' : BeautifulSoup ,
300327 're' : re ,
301328 '__builtins__' : __builtins__ ,
302329 }
303-
330+
304331 old_stdout = sys .stdout
305332 sys .stdout = StringIO ()
306-
333+
307334 try :
308335 exec (function_code , sandbox_globals )
309-
336+
310337 extract_data = sandbox_globals .get ('extract_data' )
311-
338+
312339 if not extract_data :
313340 raise NameError ("Function 'extract_data' not found in the generated code." )
314-
315- result = extract_data (self .raw_html )
316-
341+
342+ result = extract_data (self .raw_html )
317343 return True , result
318344 except Exception as e :
319345 return False , f"Error during execution: { str (e )} "
320346 finally :
321347 sys .stdout = old_stdout
322-
348+
323349 def validate_dict (self , data : dict , schema ):
350+ """
351+ validate_dict method
352+ """
324353 try :
325354 validate (instance = data , schema = schema )
326355 return True , None
327356 except ValidationError as e :
328357 errors = e .errors ()
329- return False , errors
358+ return False , errors
0 commit comments