1616from tqdm import tqdm
1717from .base_node import BaseNode
1818from pydantic import ValidationError
19- from ..utils import transform_schema
19+ from ..utils import (transform_schema ,
20+ extract_code ,
21+ syntax_focused_analysis , syntax_focused_code_generation ,
22+ execution_focused_analysis , execution_focused_code_generation ,
23+ validation_focused_analysis , validation_focused_code_generation ,
24+ semantic_focused_analysis , semantic_focused_code_generation ,
25+ are_content_equal )
2026from jsonschema import validate , ValidationError
2127import json
22- import string
2328from ..prompts import (
24- TEMPLATE_INIT_CODE_GENERATION , TEMPLATE_SYNTAX_ANALYSIS , TEMPLATE_SYNTAX_CODE_GENERATION ,
25- TEMPLATE_EXECUTION_ANALYSIS , TEMPLATE_EXECUTION_CODE_GENERATION , TEMPLATE_VALIDATION_ANALYSIS ,
26- TEMPLATE_VALIDATION_CODE_GENERATION , TEMPLATE_SEMANTIC_COMPARISON , TEMPLATE_SEMANTIC_ANALYSIS ,
27- TEMPLATE_SEMANTIC_CODE_GENERATION
29+ TEMPLATE_INIT_CODE_GENERATION , TEMPLATE_SEMANTIC_COMPARISON
2830)
2931
3032class GenerateCodeNode (BaseNode ):
@@ -141,7 +143,7 @@ def execute(self, state: dict) -> dict:
141143 def overall_reasoning_loop (self , state : dict ) -> dict :
142144 self .logger .info (f"--- (Generating Code) ---" )
143145 state ["generated_code" ] = self .generate_initial_code (state )
144- state ["generated_code" ] = self . extract_code (state ["generated_code" ])
146+ state ["generated_code" ] = extract_code (state ["generated_code" ])
145147
146148 while state ["iteration" ] < self .max_iterations ["overall" ]:
147149 state ["iteration" ] += 1
@@ -185,10 +187,10 @@ def syntax_reasoning_loop(self, state: dict) -> dict:
185187
186188 state ["errors" ]["syntax" ] = [syntax_message ]
187189 self .logger .info (f"--- (Synax Error Found: { syntax_message } ) ---" )
188- analysis = self . syntax_focused_analysis (state )
190+ analysis = syntax_focused_analysis (state , self . llm_model )
189191 self .logger .info (f"--- (Regenerating Code to fix the Error) ---" )
190- state ["generated_code" ] = self . syntax_focused_code_generation (state , analysis )
191- state ["generated_code" ] = self . extract_code (state ["generated_code" ])
192+ state ["generated_code" ] = syntax_focused_code_generation (state , analysis , self . llm_model )
193+ state ["generated_code" ] = extract_code (state ["generated_code" ])
192194 return state
193195
194196 def execution_reasoning_loop (self , state : dict ) -> dict :
@@ -201,10 +203,10 @@ def execution_reasoning_loop(self, state: dict) -> dict:
201203
202204 state ["errors" ]["execution" ] = [execution_result ]
203205 self .logger .info (f"--- (Code Execution Error: { execution_result } ) ---" )
204- analysis = self . execution_focused_analysis (state )
206+ analysis = execution_focused_analysis (state , self . llm_model )
205207 self .logger .info (f"--- (Regenerating Code to fix the Error) ---" )
206- state ["generated_code" ] = self . execution_focused_code_generation (state , analysis )
207- state ["generated_code" ] = self . extract_code (state ["generated_code" ])
208+ state ["generated_code" ] = execution_focused_code_generation (state , analysis , self . llm_model )
209+ state ["generated_code" ] = extract_code (state ["generated_code" ])
208210 return state
209211
210212 def validation_reasoning_loop (self , state : dict ) -> dict :
@@ -216,10 +218,10 @@ def validation_reasoning_loop(self, state: dict) -> dict:
216218
217219 state ["errors" ]["validation" ] = errors
218220 self .logger .info (f"--- (Code Output not compliant to the deisred Output Schema) ---" )
219- analysis = self . validation_focused_analysis (state )
221+ analysis = validation_focused_analysis (state , self . llm_model )
220222 self .logger .info (f"--- (Regenerating Code to make the Output compliant to the deisred Output Schema) ---" )
221- state ["generated_code" ] = self . validation_focused_code_generation (state , analysis )
222- state ["generated_code" ] = self . extract_code (state ["generated_code" ])
223+ state ["generated_code" ] = validation_focused_code_generation (state , analysis , self . llm_model )
224+ state ["generated_code" ] = extract_code (state ["generated_code" ])
223225 return state
224226
225227 def semantic_comparison_loop (self , state : dict ) -> dict :
@@ -231,10 +233,10 @@ def semantic_comparison_loop(self, state: dict) -> dict:
231233
232234 state ["errors" ]["semantic" ] = comparison_result ["differences" ]
233235 self .logger .info (f"--- (The informations exctrcated are not the all ones requested) ---" )
234- analysis = self . semantic_focused_analysis (state , comparison_result )
236+ analysis = semantic_focused_analysis (state , comparison_result , self . llm_model )
235237 self .logger .info (f"--- (Regenerating Code to obtain all the infromation requested) ---" )
236- state ["generated_code" ] = self . semantic_focused_code_generation (state , analysis )
237- state ["generated_code" ] = self . extract_code (state ["generated_code" ])
238+ state ["generated_code" ] = semantic_focused_code_generation (state , analysis , self . llm_model )
239+ state ["generated_code" ] = extract_code (state ["generated_code" ])
238240 return state
239241
240242 def generate_initial_code (self , state : dict ) -> str :
@@ -254,59 +256,6 @@ def generate_initial_code(self, state: dict) -> str:
254256 generated_code = chain .invoke ({})
255257 return generated_code
256258
257- def syntax_focused_analysis (self , state : dict ) -> str :
258- prompt = PromptTemplate (template = TEMPLATE_SYNTAX_ANALYSIS , input_variables = ["generated_code" , "errors" ])
259- chain = prompt | self .llm_model | StrOutputParser ()
260- return chain .invoke ({
261- "generated_code" : state ["generated_code" ],
262- "errors" : state ["errors" ]["syntax" ]
263- })
264-
265- def syntax_focused_code_generation (self , state : dict , analysis : str ) -> str :
266- prompt = PromptTemplate (template = TEMPLATE_SYNTAX_CODE_GENERATION , input_variables = ["analysis" , "generated_code" ])
267- chain = prompt | self .llm_model | StrOutputParser ()
268- return chain .invoke ({
269- "analysis" : analysis ,
270- "generated_code" : state ["generated_code" ]
271- })
272-
273- def execution_focused_analysis (self , state : dict ) -> str :
274- prompt = PromptTemplate (template = TEMPLATE_EXECUTION_ANALYSIS , input_variables = ["generated_code" , "errors" , "html_code" , "html_analysis" ])
275- chain = prompt | self .llm_model | StrOutputParser ()
276- return chain .invoke ({
277- "generated_code" : state ["generated_code" ],
278- "errors" : state ["errors" ]["execution" ],
279- "html_code" : state ["html_code" ],
280- "html_analysis" : state ["html_analysis" ]
281- })
282-
283- def execution_focused_code_generation (self , state : dict , analysis : str ) -> str :
284- prompt = PromptTemplate (template = TEMPLATE_EXECUTION_CODE_GENERATION , input_variables = ["analysis" , "generated_code" ])
285- chain = prompt | self .llm_model | StrOutputParser ()
286- return chain .invoke ({
287- "analysis" : analysis ,
288- "generated_code" : state ["generated_code" ]
289- })
290-
291- def validation_focused_analysis (self , state : dict ) -> str :
292- prompt = PromptTemplate (template = TEMPLATE_VALIDATION_ANALYSIS , input_variables = ["generated_code" , "errors" , "json_schema" , "execution_result" ])
293- chain = prompt | self .llm_model | StrOutputParser ()
294- return chain .invoke ({
295- "generated_code" : state ["generated_code" ],
296- "errors" : state ["errors" ]["validation" ],
297- "json_schema" : state ["json_schema" ],
298- "execution_result" : state ["execution_result" ]
299- })
300-
301- def validation_focused_code_generation (self , state : dict , analysis : str ) -> str :
302- prompt = PromptTemplate (template = TEMPLATE_VALIDATION_CODE_GENERATION , input_variables = ["analysis" , "generated_code" , "json_schema" ])
303- chain = prompt | self .llm_model | StrOutputParser ()
304- return chain .invoke ({
305- "analysis" : analysis ,
306- "generated_code" : state ["generated_code" ],
307- "json_schema" : state ["json_schema" ]
308- })
309-
310259 def semantic_comparison (self , generated_result : Any , reference_result : Any ) -> Dict [str , Any ]:
311260 reference_result_dict = self .output_schema (** reference_result ).dict ()
312261
@@ -337,25 +286,6 @@ def semantic_comparison(self, generated_result: Any, reference_result: Any) -> D
337286 "reference_result" : json .dumps (reference_result_dict , indent = 2 )
338287 })
339288
340- def semantic_focused_analysis (self , state : dict , comparison_result : Dict [str , Any ]) -> str :
341- prompt = PromptTemplate (template = TEMPLATE_SEMANTIC_ANALYSIS , input_variables = ["generated_code" , "differences" , "explanation" ])
342- chain = prompt | self .llm_model | StrOutputParser ()
343- return chain .invoke ({
344- "generated_code" : state ["generated_code" ],
345- "differences" : json .dumps (comparison_result ["differences" ], indent = 2 ),
346- "explanation" : comparison_result ["explanation" ]
347- })
348-
349- def semantic_focused_code_generation (self , state : dict , analysis : str ) -> str :
350- prompt = PromptTemplate (template = TEMPLATE_SEMANTIC_CODE_GENERATION , input_variables = ["analysis" , "generated_code" , "generated_result" , "reference_result" ])
351- chain = prompt | self .llm_model | StrOutputParser ()
352- return chain .invoke ({
353- "analysis" : analysis ,
354- "generated_code" : state ["generated_code" ],
355- "generated_result" : json .dumps (state ["execution_result" ], indent = 2 ),
356- "reference_result" : json .dumps (state ["reference_answer" ], indent = 2 )
357- })
358-
359289 def syntax_check (self , code ):
360290 try :
361291 ast .parse (code )
@@ -396,39 +326,4 @@ def validate_dict(self, data: dict, schema):
396326 return True , None
397327 except ValidationError as e :
398328 errors = e .errors ()
399- return False , errors
400-
401- def extract_code (self , code : str ) -> str :
402- pattern = r'```(?:python)?\n(.*?)```'
403-
404- match = re .search (pattern , code , re .DOTALL )
405-
406- return match .group (1 ) if match else code
407-
408-
409-
410- def normalize_dict (d : Dict [str , Any ]) -> Dict [str , Any ]:
411- normalized = {}
412- for key , value in d .items ():
413- if isinstance (value , str ):
414- normalized [key ] = value .lower ().strip ()
415- elif isinstance (value , dict ):
416- normalized [key ] = normalize_dict (value )
417- elif isinstance (value , list ):
418- normalized [key ] = normalize_list (value )
419- else :
420- normalized [key ] = value
421- return normalized
422-
423- def normalize_list (lst : List [Any ]) -> List [Any ]:
424- return [
425- normalize_dict (item ) if isinstance (item , dict )
426- else normalize_list (item ) if isinstance (item , list )
427- else item .lower ().strip () if isinstance (item , str )
428- else item
429- for item in lst
430- ]
431-
432- def are_content_equal (generated_result : Dict [str , Any ], reference_result : Dict [str , Any ]) -> bool :
433- """Compare two dictionaries for semantic equality."""
434- return normalize_dict (generated_result ) == normalize_dict (reference_result )
329+ return False , errors
0 commit comments