Skip to content

Commit 7c39b06

Browse files
committed
update for pylint
1 parent c181fea commit 7c39b06

File tree

2 files changed

+99
-75
lines changed

2 files changed

+99
-75
lines changed

scrapegraphai/nodes/generate_code_node.py

Lines changed: 97 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -2,36 +2,35 @@
22
GenerateCodeNode Module
33
"""
44
from typing import Any, Dict, List, Optional
5-
from langchain.prompts import PromptTemplate
6-
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
7-
from langchain_core.output_parsers import StrOutputParser
8-
from langchain_core.runnables import RunnableParallel
9-
from langchain_core.utils.pydantic import is_basemodel_subclass
10-
from langchain_community.chat_models import ChatOllama
115
import ast
126
import sys
137
from io import StringIO
14-
from bs4 import BeautifulSoup
158
import re
16-
from tqdm import tqdm
17-
from .base_node import BaseNode
9+
import json
1810
from pydantic import ValidationError
11+
from langchain.prompts import PromptTemplate
12+
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
13+
from langchain_core.output_parsers import StrOutputParser
14+
from langchain_community.chat_models import ChatOllama
15+
from bs4 import BeautifulSoup
16+
from ..prompts import (
17+
TEMPLATE_INIT_CODE_GENERATION, TEMPLATE_SEMANTIC_COMPARISON
18+
)
1919
from ..utils import (transform_schema,
2020
extract_code,
2121
syntax_focused_analysis, syntax_focused_code_generation,
2222
execution_focused_analysis, execution_focused_code_generation,
2323
validation_focused_analysis, validation_focused_code_generation,
2424
semantic_focused_analysis, semantic_focused_code_generation,
2525
are_content_equal)
26+
from .base_node import BaseNode
2627
from jsonschema import validate, ValidationError
27-
import json
28-
from ..prompts import (
29-
TEMPLATE_INIT_CODE_GENERATION, TEMPLATE_SEMANTIC_COMPARISON
30-
)
28+
3129

3230
class GenerateCodeNode(BaseNode):
3331
"""
34-
A node that generates Python code for a function that extracts data from HTML based on a output schema.
32+
A node that generates Python code for a function that extracts data
33+
from HTML based on a output schema.
3534
3635
Attributes:
3736
llm_model: An instance of a language model client, configured for generating answers.
@@ -72,15 +71,15 @@ def __init__(
7271
)
7372

7473
self.additional_info = node_config.get("additional_info")
75-
74+
7675
self.max_iterations = node_config.get("max_iterations", {
7776
"overall": 10,
7877
"syntax": 3,
7978
"execution": 3,
8079
"validation": 3,
8180
"semantic": 3
8281
})
83-
82+
8483
self.output_schema = node_config.get("schema")
8584

8685
def execute(self, state: dict) -> dict:
@@ -97,25 +96,26 @@ def execute(self, state: dict) -> dict:
9796
Raises:
9897
KeyError: If the input keys are not found in the state, indicating
9998
that the necessary information for generating an answer is missing.
100-
RuntimeError: If the maximum number of iterations is reached without obtaining the desired code.
99+
RuntimeError: If the maximum number of iterations is
100+
reached without obtaining the desired code.
101101
"""
102-
102+
103103
self.logger.info(f"--- Executing {self.node_name} Node ---")
104104

105105
input_keys = self.get_input_keys(state)
106-
106+
107107
input_data = [state[key] for key in input_keys]
108-
108+
109109
user_prompt = input_data[0]
110110
refined_prompt = input_data[1]
111111
html_info = input_data[2]
112112
reduced_html = input_data[3]
113-
answer = input_data[4]
114-
113+
answer = input_data[4]
114+
115115
self.raw_html = state['original_html'][0].page_content
116-
116+
117117
simplefied_schema = str(transform_schema(self.output_schema.schema()))
118-
118+
119119
reasoning_state = {
120120
"user_input": user_prompt,
121121
"json_schema": simplefied_schema,
@@ -133,113 +133,128 @@ def execute(self, state: dict) -> dict:
133133
},
134134
"iteration": 0
135135
}
136-
137-
136+
138137
final_state = self.overall_reasoning_loop(reasoning_state)
139-
138+
140139
state.update({self.output[0]: final_state["generated_code"]})
141140
return state
142-
141+
143142
def overall_reasoning_loop(self, state: dict) -> dict:
143+
"""
144+
overall_reasoning_loop
145+
"""
144146
self.logger.info(f"--- (Generating Code) ---")
145147
state["generated_code"] = self.generate_initial_code(state)
146148
state["generated_code"] = extract_code(state["generated_code"])
147-
149+
148150
while state["iteration"] < self.max_iterations["overall"]:
149151
state["iteration"] += 1
150152
if self.verbose:
151153
self.logger.info(f"--- Iteration {state['iteration']} ---")
152-
154+
153155
self.logger.info(f"--- (Checking Code Syntax) ---")
154156
state = self.syntax_reasoning_loop(state)
155157
if state["errors"]["syntax"]:
156158
continue
157-
159+
158160
self.logger.info(f"--- (Executing the Generated Code) ---")
159161
state = self.execution_reasoning_loop(state)
160162
if state["errors"]["execution"]:
161163
continue
162-
164+
163165
self.logger.info(f"--- (Validate the Code Output Schema) ---")
164166
state = self.validation_reasoning_loop(state)
165167
if state["errors"]["validation"]:
166168
continue
167-
169+
168170
self.logger.info(f"--- (Checking if the informations exctrcated are the ones Requested) ---")
169171
state = self.semantic_comparison_loop(state)
170172
if state["errors"]["semantic"]:
171-
continue
173+
continue
172174
break
173-
175+
174176
if state["iteration"] == self.max_iterations["overall"] and (state["errors"]["syntax"] or state["errors"]["execution"] or state["errors"]["validation"] or state["errors"]["semantic"]):
175177
raise RuntimeError("Max iterations reached without obtaining the desired code.")
176-
178+
177179
self.logger.info(f"--- (Code Generated Correctly) ---")
178-
180+
179181
return state
180-
182+
181183
def syntax_reasoning_loop(self, state: dict) -> dict:
184+
"""
185+
syntax reasoning loop
186+
"""
182187
for _ in range(self.max_iterations["syntax"]):
183188
syntax_valid, syntax_message = self.syntax_check(state["generated_code"])
184189
if syntax_valid:
185190
state["errors"]["syntax"] = []
186191
return state
187-
192+
188193
state["errors"]["syntax"] = [syntax_message]
189194
self.logger.info(f"--- (Synax Error Found: {syntax_message}) ---")
190195
analysis = syntax_focused_analysis(state, self.llm_model)
191-
self.logger.info(f"--- (Regenerating Code to fix the Error) ---")
192-
state["generated_code"] = syntax_focused_code_generation(state, analysis, self.llm_model)
196+
self.logger.info(f"""--- (Regenerating Code
197+
to fix the Error) ---""")
198+
state["generated_code"] = syntax_focused_code_generation(state,
199+
analysis, self.llm_model)
193200
state["generated_code"] = extract_code(state["generated_code"])
194201
return state
195-
202+
196203
def execution_reasoning_loop(self, state: dict) -> dict:
204+
"""
205+
execution of the reasoning loop
206+
"""
197207
for _ in range(self.max_iterations["execution"]):
198208
execution_success, execution_result = self.create_sandbox_and_execute(state["generated_code"])
199209
if execution_success:
200210
state["execution_result"] = execution_result
201211
state["errors"]["execution"] = []
202212
return state
203-
213+
204214
state["errors"]["execution"] = [execution_result]
205215
self.logger.info(f"--- (Code Execution Error: {execution_result}) ---")
206216
analysis = execution_focused_analysis(state, self.llm_model)
207217
self.logger.info(f"--- (Regenerating Code to fix the Error) ---")
208-
state["generated_code"] = execution_focused_code_generation(state, analysis, self.llm_model)
218+
state["generated_code"] = execution_focused_code_generation(state,
219+
analysis, self.llm_model)
209220
state["generated_code"] = extract_code(state["generated_code"])
210221
return state
211-
222+
212223
def validation_reasoning_loop(self, state: dict) -> dict:
213224
for _ in range(self.max_iterations["validation"]):
214-
validation, errors = self.validate_dict(state["execution_result"], self.output_schema.schema())
225+
validation, errors = self.validate_dict(state["execution_result"],
226+
self.output_schema.schema())
215227
if validation:
216228
state["errors"]["validation"] = []
217229
return state
218-
230+
219231
state["errors"]["validation"] = errors
220232
self.logger.info(f"--- (Code Output not compliant to the deisred Output Schema) ---")
221233
analysis = validation_focused_analysis(state, self.llm_model)
222234
self.logger.info(f"--- (Regenerating Code to make the Output compliant to the deisred Output Schema) ---")
223235
state["generated_code"] = validation_focused_code_generation(state, analysis, self.llm_model)
224236
state["generated_code"] = extract_code(state["generated_code"])
225237
return state
226-
238+
227239
def semantic_comparison_loop(self, state: dict) -> dict:
228240
for _ in range(self.max_iterations["semantic"]):
229241
comparison_result = self.semantic_comparison(state["execution_result"], state["reference_answer"])
230242
if comparison_result["are_semantically_equivalent"]:
231243
state["errors"]["semantic"] = []
232244
return state
233-
245+
234246
state["errors"]["semantic"] = comparison_result["differences"]
235247
self.logger.info(f"--- (The informations exctrcated are not the all ones requested) ---")
236248
analysis = semantic_focused_analysis(state, comparison_result, self.llm_model)
237249
self.logger.info(f"--- (Regenerating Code to obtain all the infromation requested) ---")
238250
state["generated_code"] = semantic_focused_code_generation(state, analysis, self.llm_model)
239251
state["generated_code"] = extract_code(state["generated_code"])
240252
return state
241-
253+
242254
def generate_initial_code(self, state: dict) -> str:
255+
"""
256+
function for generating the initial code
257+
"""
243258
prompt = PromptTemplate(
244259
template=TEMPLATE_INIT_CODE_GENERATION,
245260
partial_variables={
@@ -255,22 +270,29 @@ def generate_initial_code(self, state: dict) -> str:
255270
chain = prompt | self.llm_model | output_parser
256271
generated_code = chain.invoke({})
257272
return generated_code
258-
273+
259274
def semantic_comparison(self, generated_result: Any, reference_result: Any) -> Dict[str, Any]:
275+
"""
276+
semantic comparison formula
277+
"""
260278
reference_result_dict = self.output_schema(**reference_result).dict()
261-
262-
# Check if generated result and reference result are actually equal
263279
if are_content_equal(generated_result, reference_result_dict):
264280
return {
265281
"are_semantically_equivalent": True,
266282
"differences": [],
267283
"explanation": "The generated result and reference result are exactly equal."
268284
}
269-
285+
270286
response_schemas = [
271-
ResponseSchema(name="are_semantically_equivalent", description="Boolean indicating if the results are semantically equivalent"),
272-
ResponseSchema(name="differences", description="List of semantic differences between the results, if any"),
273-
ResponseSchema(name="explanation", description="Detailed explanation of the comparison and reasoning")
287+
ResponseSchema(name="are_semantically_equivalent",
288+
description="""Boolean indicating if the
289+
results are semantically equivalent"""),
290+
ResponseSchema(name="differences",
291+
description="""List of semantic differences
292+
between the results, if any"""),
293+
ResponseSchema(name="explanation",
294+
description="""Detailed explanation of the
295+
comparison and reasoning""")
274296
]
275297
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
276298

@@ -285,45 +307,52 @@ def semantic_comparison(self, generated_result: Any, reference_result: Any) -> D
285307
"generated_result": json.dumps(generated_result, indent=2),
286308
"reference_result": json.dumps(reference_result_dict, indent=2)
287309
})
288-
310+
289311
def syntax_check(self, code):
312+
"""
313+
syntax checker
314+
"""
290315
try:
291316
ast.parse(code)
292317
return True, "Syntax is correct."
293318
except SyntaxError as e:
294319
return False, f"Syntax error: {str(e)}"
295320

296321
def create_sandbox_and_execute(self, function_code):
297-
# Create a sandbox environment
322+
"""
323+
Create a sandbox environment
324+
"""
298325
sandbox_globals = {
299326
'BeautifulSoup': BeautifulSoup,
300327
're': re,
301328
'__builtins__': __builtins__,
302329
}
303-
330+
304331
old_stdout = sys.stdout
305332
sys.stdout = StringIO()
306-
333+
307334
try:
308335
exec(function_code, sandbox_globals)
309-
336+
310337
extract_data = sandbox_globals.get('extract_data')
311-
338+
312339
if not extract_data:
313340
raise NameError("Function 'extract_data' not found in the generated code.")
314-
315-
result = extract_data(self.raw_html)
316-
341+
342+
result = extract_data(self.raw_html)
317343
return True, result
318344
except Exception as e:
319345
return False, f"Error during execution: {str(e)}"
320346
finally:
321347
sys.stdout = old_stdout
322-
348+
323349
def validate_dict(self, data: dict, schema):
350+
"""
351+
validate_dict method
352+
"""
324353
try:
325354
validate(instance=data, schema=schema)
326355
return True, None
327356
except ValidationError as e:
328357
errors = e.errors()
329-
return False, errors
358+
return False, errors

scrapegraphai/nodes/prompt_refiner_node.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,7 @@
44
from typing import List, Optional
55
from langchain.prompts import PromptTemplate
66
from langchain_core.output_parsers import StrOutputParser
7-
from langchain_core.runnables import RunnableParallel
8-
from langchain_core.utils.pydantic import is_basemodel_subclass
9-
from langchain_openai import ChatOpenAI, AzureChatOpenAI
10-
from langchain_mistralai import ChatMistralAI
117
from langchain_community.chat_models import ChatOllama
12-
from tqdm import tqdm
138
from .base_node import BaseNode
149
from ..utils import transform_schema
1510
from ..prompts import (
@@ -61,7 +56,7 @@ def __init__(
6156
)
6257

6358
self.additional_info = node_config.get("additional_info")
64-
59+
6560
self.output_schema = node_config.get("schema")
6661

6762
def execute(self, state: dict) -> dict:
@@ -85,7 +80,7 @@ def execute(self, state: dict) -> dict:
8580
user_prompt = state['user_prompt']
8681

8782
self.simplefied_schema = transform_schema(self.output_schema.schema())
88-
83+
8984
if self.additional_info is not None:
9085
prompt = PromptTemplate(
9186
template=TEMPLATE_REFINER_WITH_CONTEXT,

0 commit comments

Comments
 (0)