Skip to content

Commit ce841e2

Browse files
Code generation refactoring
1 parent 2d2c719 commit ce841e2

File tree

8 files changed

+168
-128
lines changed

8 files changed

+168
-128
lines changed

scrapegraphai/nodes/generate_code_node.py

Lines changed: 22 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,17 @@
1616
from tqdm import tqdm
1717
from .base_node import BaseNode
1818
from pydantic import ValidationError
19-
from ..utils import transform_schema
19+
from ..utils import (transform_schema,
20+
extract_code,
21+
syntax_focused_analysis, syntax_focused_code_generation,
22+
execution_focused_analysis, execution_focused_code_generation,
23+
validation_focused_analysis, validation_focused_code_generation,
24+
semantic_focused_analysis, semantic_focused_code_generation,
25+
are_content_equal)
2026
from jsonschema import validate, ValidationError
2127
import json
22-
import string
2328
from ..prompts import (
24-
TEMPLATE_INIT_CODE_GENERATION, TEMPLATE_SYNTAX_ANALYSIS, TEMPLATE_SYNTAX_CODE_GENERATION,
25-
TEMPLATE_EXECUTION_ANALYSIS, TEMPLATE_EXECUTION_CODE_GENERATION, TEMPLATE_VALIDATION_ANALYSIS,
26-
TEMPLATE_VALIDATION_CODE_GENERATION, TEMPLATE_SEMANTIC_COMPARISON, TEMPLATE_SEMANTIC_ANALYSIS,
27-
TEMPLATE_SEMANTIC_CODE_GENERATION
29+
TEMPLATE_INIT_CODE_GENERATION, TEMPLATE_SEMANTIC_COMPARISON
2830
)
2931

3032
class GenerateCodeNode(BaseNode):
@@ -141,7 +143,7 @@ def execute(self, state: dict) -> dict:
141143
def overall_reasoning_loop(self, state: dict) -> dict:
142144
self.logger.info(f"--- (Generating Code) ---")
143145
state["generated_code"] = self.generate_initial_code(state)
144-
state["generated_code"] = self.extract_code(state["generated_code"])
146+
state["generated_code"] = extract_code(state["generated_code"])
145147

146148
while state["iteration"] < self.max_iterations["overall"]:
147149
state["iteration"] += 1
@@ -185,10 +187,10 @@ def syntax_reasoning_loop(self, state: dict) -> dict:
185187

186188
state["errors"]["syntax"] = [syntax_message]
187189
self.logger.info(f"--- (Synax Error Found: {syntax_message}) ---")
188-
analysis = self.syntax_focused_analysis(state)
190+
analysis = syntax_focused_analysis(state, self.llm_model)
189191
self.logger.info(f"--- (Regenerating Code to fix the Error) ---")
190-
state["generated_code"] = self.syntax_focused_code_generation(state, analysis)
191-
state["generated_code"] = self.extract_code(state["generated_code"])
192+
state["generated_code"] = syntax_focused_code_generation(state, analysis, self.llm_model)
193+
state["generated_code"] = extract_code(state["generated_code"])
192194
return state
193195

194196
def execution_reasoning_loop(self, state: dict) -> dict:
@@ -201,10 +203,10 @@ def execution_reasoning_loop(self, state: dict) -> dict:
201203

202204
state["errors"]["execution"] = [execution_result]
203205
self.logger.info(f"--- (Code Execution Error: {execution_result}) ---")
204-
analysis = self.execution_focused_analysis(state)
206+
analysis = execution_focused_analysis(state, self.llm_model)
205207
self.logger.info(f"--- (Regenerating Code to fix the Error) ---")
206-
state["generated_code"] = self.execution_focused_code_generation(state, analysis)
207-
state["generated_code"] = self.extract_code(state["generated_code"])
208+
state["generated_code"] = execution_focused_code_generation(state, analysis, self.llm_model)
209+
state["generated_code"] = extract_code(state["generated_code"])
208210
return state
209211

210212
def validation_reasoning_loop(self, state: dict) -> dict:
@@ -216,10 +218,10 @@ def validation_reasoning_loop(self, state: dict) -> dict:
216218

217219
state["errors"]["validation"] = errors
218220
self.logger.info(f"--- (Code Output not compliant to the deisred Output Schema) ---")
219-
analysis = self.validation_focused_analysis(state)
221+
analysis = validation_focused_analysis(state, self.llm_model)
220222
self.logger.info(f"--- (Regenerating Code to make the Output compliant to the deisred Output Schema) ---")
221-
state["generated_code"] = self.validation_focused_code_generation(state, analysis)
222-
state["generated_code"] = self.extract_code(state["generated_code"])
223+
state["generated_code"] = validation_focused_code_generation(state, analysis, self.llm_model)
224+
state["generated_code"] = extract_code(state["generated_code"])
223225
return state
224226

225227
def semantic_comparison_loop(self, state: dict) -> dict:
@@ -231,10 +233,10 @@ def semantic_comparison_loop(self, state: dict) -> dict:
231233

232234
state["errors"]["semantic"] = comparison_result["differences"]
233235
self.logger.info(f"--- (The informations exctrcated are not the all ones requested) ---")
234-
analysis = self.semantic_focused_analysis(state, comparison_result)
236+
analysis = semantic_focused_analysis(state, comparison_result, self.llm_model)
235237
self.logger.info(f"--- (Regenerating Code to obtain all the infromation requested) ---")
236-
state["generated_code"] = self.semantic_focused_code_generation(state, analysis)
237-
state["generated_code"] = self.extract_code(state["generated_code"])
238+
state["generated_code"] = semantic_focused_code_generation(state, analysis, self.llm_model)
239+
state["generated_code"] = extract_code(state["generated_code"])
238240
return state
239241

240242
def generate_initial_code(self, state: dict) -> str:
@@ -254,59 +256,6 @@ def generate_initial_code(self, state: dict) -> str:
254256
generated_code = chain.invoke({})
255257
return generated_code
256258

257-
def syntax_focused_analysis(self, state: dict) -> str:
258-
prompt = PromptTemplate(template=TEMPLATE_SYNTAX_ANALYSIS, input_variables=["generated_code", "errors"])
259-
chain = prompt | self.llm_model | StrOutputParser()
260-
return chain.invoke({
261-
"generated_code": state["generated_code"],
262-
"errors": state["errors"]["syntax"]
263-
})
264-
265-
def syntax_focused_code_generation(self, state: dict, analysis: str) -> str:
266-
prompt = PromptTemplate(template=TEMPLATE_SYNTAX_CODE_GENERATION, input_variables=["analysis", "generated_code"])
267-
chain = prompt | self.llm_model | StrOutputParser()
268-
return chain.invoke({
269-
"analysis": analysis,
270-
"generated_code": state["generated_code"]
271-
})
272-
273-
def execution_focused_analysis(self, state: dict) -> str:
274-
prompt = PromptTemplate(template=TEMPLATE_EXECUTION_ANALYSIS, input_variables=["generated_code", "errors", "html_code", "html_analysis"])
275-
chain = prompt | self.llm_model | StrOutputParser()
276-
return chain.invoke({
277-
"generated_code": state["generated_code"],
278-
"errors": state["errors"]["execution"],
279-
"html_code": state["html_code"],
280-
"html_analysis": state["html_analysis"]
281-
})
282-
283-
def execution_focused_code_generation(self, state: dict, analysis: str) -> str:
284-
prompt = PromptTemplate(template=TEMPLATE_EXECUTION_CODE_GENERATION, input_variables=["analysis", "generated_code"])
285-
chain = prompt | self.llm_model | StrOutputParser()
286-
return chain.invoke({
287-
"analysis": analysis,
288-
"generated_code": state["generated_code"]
289-
})
290-
291-
def validation_focused_analysis(self, state: dict) -> str:
292-
prompt = PromptTemplate(template=TEMPLATE_VALIDATION_ANALYSIS, input_variables=["generated_code", "errors", "json_schema", "execution_result"])
293-
chain = prompt | self.llm_model | StrOutputParser()
294-
return chain.invoke({
295-
"generated_code": state["generated_code"],
296-
"errors": state["errors"]["validation"],
297-
"json_schema": state["json_schema"],
298-
"execution_result": state["execution_result"]
299-
})
300-
301-
def validation_focused_code_generation(self, state: dict, analysis: str) -> str:
302-
prompt = PromptTemplate(template=TEMPLATE_VALIDATION_CODE_GENERATION, input_variables=["analysis", "generated_code", "json_schema"])
303-
chain = prompt | self.llm_model | StrOutputParser()
304-
return chain.invoke({
305-
"analysis": analysis,
306-
"generated_code": state["generated_code"],
307-
"json_schema": state["json_schema"]
308-
})
309-
310259
def semantic_comparison(self, generated_result: Any, reference_result: Any) -> Dict[str, Any]:
311260
reference_result_dict = self.output_schema(**reference_result).dict()
312261

@@ -337,25 +286,6 @@ def semantic_comparison(self, generated_result: Any, reference_result: Any) -> D
337286
"reference_result": json.dumps(reference_result_dict, indent=2)
338287
})
339288

340-
def semantic_focused_analysis(self, state: dict, comparison_result: Dict[str, Any]) -> str:
341-
prompt = PromptTemplate(template=TEMPLATE_SEMANTIC_ANALYSIS, input_variables=["generated_code", "differences", "explanation"])
342-
chain = prompt | self.llm_model | StrOutputParser()
343-
return chain.invoke({
344-
"generated_code": state["generated_code"],
345-
"differences": json.dumps(comparison_result["differences"], indent=2),
346-
"explanation": comparison_result["explanation"]
347-
})
348-
349-
def semantic_focused_code_generation(self, state: dict, analysis: str) -> str:
350-
prompt = PromptTemplate(template=TEMPLATE_SEMANTIC_CODE_GENERATION, input_variables=["analysis", "generated_code", "generated_result", "reference_result"])
351-
chain = prompt | self.llm_model | StrOutputParser()
352-
return chain.invoke({
353-
"analysis": analysis,
354-
"generated_code": state["generated_code"],
355-
"generated_result": json.dumps(state["execution_result"], indent=2),
356-
"reference_result": json.dumps(state["reference_answer"], indent=2)
357-
})
358-
359289
def syntax_check(self, code):
360290
try:
361291
ast.parse(code)
@@ -396,39 +326,4 @@ def validate_dict(self, data: dict, schema):
396326
return True, None
397327
except ValidationError as e:
398328
errors = e.errors()
399-
return False, errors
400-
401-
def extract_code(self, code: str) -> str:
402-
pattern = r'```(?:python)?\n(.*?)```'
403-
404-
match = re.search(pattern, code, re.DOTALL)
405-
406-
return match.group(1) if match else code
407-
408-
409-
410-
def normalize_dict(d: Dict[str, Any]) -> Dict[str, Any]:
411-
normalized = {}
412-
for key, value in d.items():
413-
if isinstance(value, str):
414-
normalized[key] = value.lower().strip()
415-
elif isinstance(value, dict):
416-
normalized[key] = normalize_dict(value)
417-
elif isinstance(value, list):
418-
normalized[key] = normalize_list(value)
419-
else:
420-
normalized[key] = value
421-
return normalized
422-
423-
def normalize_list(lst: List[Any]) -> List[Any]:
424-
return [
425-
normalize_dict(item) if isinstance(item, dict)
426-
else normalize_list(item) if isinstance(item, list)
427-
else item.lower().strip() if isinstance(item, str)
428-
else item
429-
for item in lst
430-
]
431-
432-
def are_content_equal(generated_result: Dict[str, Any], reference_result: Dict[str, Any]) -> bool:
433-
"""Compare two dictionaries for semantic equality."""
434-
return normalize_dict(generated_result) == normalize_dict(reference_result)
329+
return False, errors

scrapegraphai/prompts/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,9 @@
1313
from .search_node_with_context_prompts import TEMPLATE_SEARCH_WITH_CONTEXT_CHUNKS, TEMPLATE_SEARCH_WITH_CONTEXT_NO_CHUNKS
1414
from .prompt_refiner_node_prompts import TEMPLATE_REFINER, TEMPLATE_REFINER_WITH_CONTEXT
1515
from .html_analyzer_node_prompts import TEMPLATE_HTML_ANALYSIS, TEMPLATE_HTML_ANALYSIS_WITH_CONTEXT
16-
from .generate_code_node_prompts import TEMPLATE_INIT_CODE_GENERATION, TEMPLATE_SYNTAX_ANALYSIS, TEMPLATE_SYNTAX_CODE_GENERATION, TEMPLATE_EXECUTION_ANALYSIS, TEMPLATE_EXECUTION_CODE_GENERATION, TEMPLATE_VALIDATION_ANALYSIS, TEMPLATE_VALIDATION_CODE_GENERATION, TEMPLATE_SEMANTIC_COMPARISON, TEMPLATE_SEMANTIC_ANALYSIS, TEMPLATE_SEMANTIC_CODE_GENERATION
16+
from .generate_code_node_prompts import (TEMPLATE_INIT_CODE_GENERATION,
17+
TEMPLATE_SYNTAX_ANALYSIS, TEMPLATE_SYNTAX_CODE_GENERATION,
18+
TEMPLATE_EXECUTION_ANALYSIS, TEMPLATE_EXECUTION_CODE_GENERATION,
19+
TEMPLATE_VALIDATION_ANALYSIS, TEMPLATE_VALIDATION_CODE_GENERATION,
20+
TEMPLATE_SEMANTIC_COMPARISON, TEMPLATE_SEMANTIC_ANALYSIS,
21+
TEMPLATE_SEMANTIC_CODE_GENERATION)

scrapegraphai/utils/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,9 @@
1919
from .split_text_into_chunks import split_text_into_chunks
2020
from .llm_callback_manager import CustomLLMCallbackManager
2121
from .schema_trasform import transform_schema
22+
from .cleanup_code import extract_code
23+
from .dict_content_compare import are_content_equal
24+
from .code_error_analysis import (syntax_focused_analysis, execution_focused_analysis,
25+
validation_focused_analysis, semantic_focused_analysis)
26+
from .code_error_correction import (syntax_focused_code_generation, execution_focused_code_generation,
27+
validation_focused_code_generation, semantic_focused_code_generation)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""
2+
This utility function extracts the code from a given string.
3+
"""
4+
import re
5+
6+
def extract_code(code: str) -> str:
    """Return the body of the first fenced code block found in *code*.

    Accepts both anonymous (```) and python-tagged (```python) fences.
    When no fenced block is present, the input string is returned
    unchanged as a best-effort fallback.
    """
    fence = re.search(r'```(?:python)?\n(.*?)```', code, re.DOTALL)
    if fence is None:
        # No markdown fence: assume the whole string is already code.
        return code
    return fence.group(1)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""
2+
This module contains the functions that are used to generate the prompts for the code error analysis.
3+
"""
4+
from typing import Any, Dict
5+
from langchain.prompts import PromptTemplate
6+
from langchain_core.output_parsers import StrOutputParser
7+
import json
8+
from ..prompts import (
9+
TEMPLATE_SYNTAX_ANALYSIS, TEMPLATE_EXECUTION_ANALYSIS,
10+
TEMPLATE_VALIDATION_ANALYSIS, TEMPLATE_SEMANTIC_ANALYSIS
11+
)
12+
13+
def syntax_focused_analysis(state: dict, llm_model) -> str:
    """Run an LLM analysis of the syntax errors recorded in *state*.

    Builds a prompt from TEMPLATE_SYNTAX_ANALYSIS, pipes it through
    *llm_model*, and returns the model's plain-text analysis.
    """
    analysis_prompt = PromptTemplate(
        template=TEMPLATE_SYNTAX_ANALYSIS,
        input_variables=["generated_code", "errors"],
    )
    pipeline = analysis_prompt | llm_model | StrOutputParser()
    payload = {
        "generated_code": state["generated_code"],
        "errors": state["errors"]["syntax"],
    }
    return pipeline.invoke(payload)
20+
21+
def execution_focused_analysis(state: dict, llm_model) -> str:
    """Run an LLM analysis of a runtime failure of the generated code.

    Feeds the failing code, its execution errors, and the scraped HTML
    plus its prior analysis into TEMPLATE_EXECUTION_ANALYSIS and
    returns the model's plain-text diagnosis.
    """
    analysis_prompt = PromptTemplate(
        template=TEMPLATE_EXECUTION_ANALYSIS,
        input_variables=["generated_code", "errors", "html_code", "html_analysis"],
    )
    pipeline = analysis_prompt | llm_model | StrOutputParser()
    payload = {
        "generated_code": state["generated_code"],
        "errors": state["errors"]["execution"],
        "html_code": state["html_code"],
        "html_analysis": state["html_analysis"],
    }
    return pipeline.invoke(payload)
30+
31+
def validation_focused_analysis(state: dict, llm_model) -> str:
    """Run an LLM analysis of schema-validation failures.

    Supplies the generated code, the validation errors, the target
    JSON schema, and the actual execution result to
    TEMPLATE_VALIDATION_ANALYSIS; returns the model's analysis text.
    """
    analysis_prompt = PromptTemplate(
        template=TEMPLATE_VALIDATION_ANALYSIS,
        input_variables=["generated_code", "errors", "json_schema", "execution_result"],
    )
    pipeline = analysis_prompt | llm_model | StrOutputParser()
    payload = {
        "generated_code": state["generated_code"],
        "errors": state["errors"]["validation"],
        "json_schema": state["json_schema"],
        "execution_result": state["execution_result"],
    }
    return pipeline.invoke(payload)
40+
41+
def semantic_focused_analysis(state: dict, comparison_result: Dict[str, Any], llm_model) -> str:
    """Run an LLM analysis of semantic differences in the output.

    *comparison_result* must provide ``differences`` (serialized to
    pretty-printed JSON for the prompt) and ``explanation``. Returns
    the model's plain-text analysis.
    """
    analysis_prompt = PromptTemplate(
        template=TEMPLATE_SEMANTIC_ANALYSIS,
        input_variables=["generated_code", "differences", "explanation"],
    )
    pipeline = analysis_prompt | llm_model | StrOutputParser()
    payload = {
        "generated_code": state["generated_code"],
        "differences": json.dumps(comparison_result["differences"], indent=2),
        "explanation": comparison_result["explanation"],
    }
    return pipeline.invoke(payload)
})
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""
2+
This module contains the code generation functions used to correct different types of errors.
3+
"""
4+
from langchain.prompts import PromptTemplate
5+
from langchain_core.output_parsers import StrOutputParser
6+
import json
7+
from ..prompts import (
8+
TEMPLATE_SYNTAX_CODE_GENERATION, TEMPLATE_EXECUTION_CODE_GENERATION,
9+
TEMPLATE_VALIDATION_CODE_GENERATION, TEMPLATE_SEMANTIC_CODE_GENERATION
10+
)
11+
12+
def syntax_focused_code_generation(state: dict, analysis: str, llm_model) -> str:
    """Regenerate code from a syntax-error *analysis* via the LLM.

    Combines the analysis text with the previously generated code
    through TEMPLATE_SYNTAX_CODE_GENERATION and returns the model's
    raw (possibly fenced) code output.
    """
    generation_prompt = PromptTemplate(
        template=TEMPLATE_SYNTAX_CODE_GENERATION,
        input_variables=["analysis", "generated_code"],
    )
    pipeline = generation_prompt | llm_model | StrOutputParser()
    payload = {
        "analysis": analysis,
        "generated_code": state["generated_code"],
    }
    return pipeline.invoke(payload)
19+
20+
def execution_focused_code_generation(state: dict, analysis: str, llm_model) -> str:
    """Regenerate code from an execution-error *analysis* via the LLM.

    Combines the analysis text with the previously generated code
    through TEMPLATE_EXECUTION_CODE_GENERATION and returns the model's
    raw (possibly fenced) code output.
    """
    generation_prompt = PromptTemplate(
        template=TEMPLATE_EXECUTION_CODE_GENERATION,
        input_variables=["analysis", "generated_code"],
    )
    pipeline = generation_prompt | llm_model | StrOutputParser()
    payload = {
        "analysis": analysis,
        "generated_code": state["generated_code"],
    }
    return pipeline.invoke(payload)
27+
28+
def validation_focused_code_generation(state: dict, analysis: str, llm_model) -> str:
    """Regenerate code from a validation-error *analysis* via the LLM.

    Besides the analysis and the previous code, the target JSON schema
    is passed so the model can make the output compliant. Returns the
    model's raw (possibly fenced) code output.
    """
    generation_prompt = PromptTemplate(
        template=TEMPLATE_VALIDATION_CODE_GENERATION,
        input_variables=["analysis", "generated_code", "json_schema"],
    )
    pipeline = generation_prompt | llm_model | StrOutputParser()
    payload = {
        "analysis": analysis,
        "generated_code": state["generated_code"],
        "json_schema": state["json_schema"],
    }
    return pipeline.invoke(payload)
36+
37+
def semantic_focused_code_generation(state: dict, analysis: str, llm_model) -> str:
    """Regenerate code from a semantic-difference *analysis* via the LLM.

    The previous execution result and the reference answer are
    serialized to pretty-printed JSON so the model can compare them.
    Returns the model's raw (possibly fenced) code output.
    """
    generation_prompt = PromptTemplate(
        template=TEMPLATE_SEMANTIC_CODE_GENERATION,
        input_variables=["analysis", "generated_code", "generated_result", "reference_result"],
    )
    pipeline = generation_prompt | llm_model | StrOutputParser()
    payload = {
        "analysis": analysis,
        "generated_code": state["generated_code"],
        "generated_result": json.dumps(state["execution_result"], indent=2),
        "reference_result": json.dumps(state["reference_answer"], indent=2),
    }
    return pipeline.invoke(payload)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""
2+
Utility functions for comparing the content of two dictionaries.
3+
"""
4+
from typing import Any, Dict, List
5+
6+
def normalize_dict(d: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively normalize every string inside *d*.

    Strings are lower-cased and stripped of surrounding whitespace;
    nested dicts and lists are normalized recursively; any other value
    is kept as-is. A new dict is returned — the input is not mutated.
    """
    result: Dict[str, Any] = {}
    for key, value in d.items():
        if isinstance(value, dict):
            result[key] = normalize_dict(value)
        elif isinstance(value, list):
            result[key] = normalize_list(value)
        elif isinstance(value, str):
            result[key] = value.lower().strip()
        else:
            result[key] = value
    return result

def normalize_list(lst: List[Any]) -> List[Any]:
    """Recursively normalize every element of *lst*.

    Mirrors :func:`normalize_dict`: strings are lower-cased/stripped,
    nested containers are normalized recursively, other values pass
    through unchanged. Returns a new list.
    """
    normalized_items: List[Any] = []
    for item in lst:
        if isinstance(item, dict):
            normalized_items.append(normalize_dict(item))
        elif isinstance(item, list):
            normalized_items.append(normalize_list(item))
        elif isinstance(item, str):
            normalized_items.append(item.lower().strip())
        else:
            normalized_items.append(item)
    return normalized_items
27+
28+
def are_content_equal(generated_result: Dict[str, Any], reference_result: Dict[str, Any]) -> bool:
    """Compare two dictionaries for semantic equality.

    Both sides are normalized first (case-insensitive, whitespace-
    stripped strings, recursively), so cosmetic differences in string
    formatting do not count as mismatches.
    """
    normalized_generated = normalize_dict(generated_result)
    normalized_reference = normalize_dict(reference_result)
    return normalized_generated == normalized_reference

0 commit comments

Comments
 (0)