Commit 26de5dd

Merge branch 'pre/beta' into ligthweigthing_library
2 parents 986c8a1 + f7ba1f3 commit 26de5dd

20 files changed: 40 additions, 104 deletions

scrapegraphai/graphs/abstract_graph.py

Lines changed: 1 addition & 0 deletions

@@ -149,6 +149,7 @@ def handle_model(model_name, provider, token_key, default_token=8192):
             "ollama", "oneapi", "nvidia", "groq", "google_vertexai",
             "bedrock", "mistralai", "hugging_face", "deepseek", "ernie", "fireworks"]

+
         if llm_params["model"].split("/")[0] not in known_models and llm_params["model"].split("-")[0] not in known_models:
             raise ValueError(f"Model '{llm_params['model']}' is not supported")
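
For context on the hunk above: the guard treats the text before the first "/" or "-" as a provider prefix, and a model passes if either prefix appears in known_models. A minimal sketch of that check in plain Python (the model strings and the abbreviated list are illustrative):

known_models = ["ollama", "groq", "bedrock", "mistralai"]  # abbreviated

for model in ["ollama/llama3", "groq-mixtral", "acme/foo"]:
    supported = (model.split("/")[0] in known_models
                 or model.split("-")[0] in known_models)
    print(model, "->", "supported" if supported else "not supported")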

scrapegraphai/nodes/generate_answer_csv_node.py

Lines changed: 12 additions & 11 deletions

@@ -9,7 +9,8 @@
 from tqdm import tqdm
 from ..utils.logging import get_logger
 from .base_node import BaseNode
-from ..prompts.generate_answer_node_csv_prompts import TEMPLATE_CHUKS_CSV, TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV
+from ..prompts.generate_answer_node_csv_prompts import (TEMPLATE_CHUKS_CSV,
+                                                        TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV)

 class GenerateAnswerCSVNode(BaseNode):
     """
@@ -95,22 +96,22 @@ def execute(self, state):
         else:
             output_parser = JsonOutputParser()

-        TEMPLATE_NO_CHUKS_CSV_prompt = TEMPLATE_NO_CHUKS_CSV
-        TEMPLATE_CHUKS_CSV_prompt = TEMPLATE_CHUKS_CSV
-        TEMPLATE_MERGE_CSV_prompt = TEMPLATE_MERGE_CSV
+        TEMPLATE_NO_CHUKS_CSV_PROMPT = TEMPLATE_NO_CHUKS_CSV
+        TEMPLATE_CHUKS_CSV_PROMPT = TEMPLATE_CHUKS_CSV
+        TEMPLATE_MERGE_CSV_PROMPT = TEMPLATE_MERGE_CSV

         if self.additional_info is not None:
-            TEMPLATE_NO_CHUKS_CSV_prompt = self.additional_info + TEMPLATE_NO_CHUKS_CSV
-            TEMPLATE_CHUKS_CSV_prompt = self.additional_info + TEMPLATE_CHUKS_CSV
-            TEMPLATE_MERGE_CSV_prompt = self.additional_info + TEMPLATE_MERGE_CSV
+            TEMPLATE_NO_CHUKS_CSV_PROMPT = self.additional_info + TEMPLATE_NO_CHUKS_CSV
+            TEMPLATE_CHUKS_CSV_PROMPT = self.additional_info + TEMPLATE_CHUKS_CSV
+            TEMPLATE_MERGE_CSV_PROMPT = self.additional_info + TEMPLATE_MERGE_CSV

         format_instructions = output_parser.get_format_instructions()

         chains_dict = {}

         if len(doc) == 1:
             prompt = PromptTemplate(
-                template=TEMPLATE_NO_CHUKS_CSV_prompt,
+                template=TEMPLATE_NO_CHUKS_CSV_PROMPT,
                 input_variables=["question"],
                 partial_variables={
                     "context": doc,
@@ -127,7 +128,7 @@ def execute(self, state):
             tqdm(doc, desc="Processing chunks", disable=not self.verbose)
         ):
             prompt = PromptTemplate(
-                template=TEMPLATE_CHUKS_CSV_prompt,
+                template=TEMPLATE_CHUKS_CSV_PROMPT,
                 input_variables=["question"],
                 partial_variables={
                     "context": chunk,
@@ -144,7 +145,7 @@ def execute(self, state):
         batch_results = async_runner.invoke({"question": user_prompt})

         merge_prompt = PromptTemplate(
-            template = TEMPLATE_MERGE_CSV_prompt,
+            template = TEMPLATE_MERGE_CSV_PROMPT,
             input_variables=["context", "question"],
             partial_variables={"format_instructions": format_instructions},
         )
@@ -153,4 +154,4 @@ def execute(self, state):
         answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})

         state.update({self.output[0]: answer})
-        return state
+        return state
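
A note on the PromptTemplate calls whose template constants this file renames: partial_variables binds values such as context and format_instructions at construction time, while input_variables stay runtime inputs. A minimal sketch, assuming langchain_core is installed (template text and values are illustrative):

from langchain_core.prompts import PromptTemplate

prompt = PromptTemplate(
    template="Context: {context}\n{format_instructions}\nQuestion: {question}",
    input_variables=["question"],
    partial_variables={"context": "csv rows here",
                       "format_instructions": "Reply as JSON."},
)
print(prompt.format(question="Which product is cheapest?"))  # partials pre-filled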

scrapegraphai/nodes/generate_scraper_node.py

Lines changed: 0 additions & 2 deletions

@@ -67,10 +67,8 @@ def execute(self, state: dict) -> dict:

         self.logger.info(f"--- Executing {self.node_name} Node ---")

-        # Interpret input keys based on the provided input expression
         input_keys = self.get_input_keys(state)

-        # Fetching data from the state based on the input keys
         input_data = [state[key] for key in input_keys]

         user_prompt = input_data[0]
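
The two deleted comments described the state-access idiom that opens most execute() methods: resolve the node's input keys, then index the shared state with them. A plain-Python sketch with hypothetical keys:

state = {"user_prompt": "Extract the page title", "doc": ["<html>...</html>"]}
input_keys = ["user_prompt", "doc"]  # stand-in for self.get_input_keys(state)
input_data = [state[key] for key in input_keys]
user_prompt = input_data[0]  # the first key's value drives the prompt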

scrapegraphai/nodes/get_probable_tags_node.py

Lines changed: 0 additions & 4 deletions

@@ -58,10 +58,8 @@ def execute(self, state: dict) -> dict:

         self.logger.info(f"--- Executing {self.node_name} Node ---")

-        # Interpret input keys based on the provided input expression
         input_keys = self.get_input_keys(state)

-        # Fetching data from the state based on the input keys
         input_data = [state[key] for key in input_keys]

         user_prompt = input_data[0]
@@ -88,10 +86,8 @@ def execute(self, state: dict) -> dict:
             },
         )

-        # Execute the chain to get probable tags
         tag_answer = tag_prompt | self.llm_model | output_parser
         probable_tags = tag_answer.invoke({"question": user_prompt})

-        # Update the dictionary with probable tags
         state.update({self.output[0]: probable_tags})
         return state
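
The kept line tag_prompt | self.llm_model | output_parser is LangChain pipe composition: the prompt's output feeds the model, whose output feeds the parser. A runnable sketch with a canned fake model, assuming langchain_core and langchain_community are installed (prompt text and response are illustrative):

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_community.llms.fake import FakeListLLM

tag_prompt = PromptTemplate(template="Which HTML tags hold: {question}?",
                            input_variables=["question"])
llm = FakeListLLM(responses=["h1, title"])  # canned stand-in for a real model
chain = tag_prompt | llm | StrOutputParser()
print(chain.invoke({"question": "page headings"}))  # -> h1, title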

scrapegraphai/nodes/graph_iterator_node.py

Lines changed: 0 additions & 3 deletions

@@ -103,7 +103,6 @@ async def _async_execute(self, state: dict, batchsize: int) -> dict:
         if graph_instance is None:
             raise ValueError("graph instance is required for concurrent execution")

-        # Assign depth level to the graph
         if "graph_depth" in graph_instance.config:
             graph_instance.config["graph_depth"] += 1
         else:
@@ -113,14 +112,12 @@ async def _async_execute(self, state: dict, batchsize: int) -> dict:

         participants = []

-        # semaphore to limit the number of concurrent tasks
         semaphore = asyncio.Semaphore(batchsize)

         async def _async_run(graph):
             async with semaphore:
                 return await asyncio.to_thread(graph.run)

-        # creates a deepcopy of the graph instance for each endpoint
         for url in urls:
             instance = copy.copy(graph_instance)
             instance.source = url
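
The kept lines around the removed comments form a bounded-concurrency pattern: a semaphore caps in-flight work while asyncio.to_thread moves each blocking graph.run() off the event loop. A stdlib-only sketch (work() is a hypothetical stand-in for graph.run()):

import asyncio
import time

async def main(urls, batchsize=2):
    semaphore = asyncio.Semaphore(batchsize)

    def work(url):  # stand-in for a blocking graph.run()
        time.sleep(0.1)
        return url

    async def run_one(url):
        async with semaphore:  # at most batchsize threads in flight
            return await asyncio.to_thread(work, url)

    return await asyncio.gather(*(run_one(u) for u in urls))

print(asyncio.run(main([f"https://example.com/{i}" for i in range(5)])))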

scrapegraphai/nodes/merge_answers_node.py

Lines changed: 0 additions & 5 deletions

@@ -56,21 +56,17 @@ def execute(self, state: dict) -> dict:

         self.logger.info(f"--- Executing {self.node_name} Node ---")

-        # Interpret input keys based on the provided input expression
         input_keys = self.get_input_keys(state)

-        # Fetching data from the state based on the input keys
         input_data = [state[key] for key in input_keys]

         user_prompt = input_data[0]
         answers = input_data[1]

-        # merge the answers in one string
         answers_str = ""
         for i, answer in enumerate(answers):
             answers_str += f"CONTENT WEBSITE {i+1}: {answer}\n"

-        # Initialize the output parser
         if self.node_config.get("schema", None) is not None:
             output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
         else:
@@ -90,6 +86,5 @@ def execute(self, state: dict) -> dict:
         merge_chain = prompt_template | self.llm_model | output_parser
         answer = merge_chain.invoke({"user_prompt": user_prompt})

-        # Update the state with the generated answer
         state.update({self.output[0]: answer})
         return state
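
On the schema branch kept above: JsonOutputParser(pydantic_object=...) derives its format instructions from a Pydantic model, which is what get_format_instructions() later injects into the prompt. A minimal sketch, assuming langchain_core and pydantic are installed (the Product schema is hypothetical):

from langchain_core.output_parsers import JsonOutputParser
from pydantic import BaseModel

class Product(BaseModel):
    name: str
    price: float

parser = JsonOutputParser(pydantic_object=Product)
print(parser.get_format_instructions())  # JSON-schema text for the prompt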

scrapegraphai/nodes/parse_node.py

Lines changed: 1 addition & 4 deletions

@@ -59,13 +59,11 @@ def execute(self, state: dict) -> dict:

         self.logger.info(f"--- Executing {self.node_name} Node ---")

-        # Interpret input keys based on the provided input expression
         input_keys = self.get_input_keys(state)

-        # Fetching data from the state based on the input keys
         input_data = [state[key] for key in input_keys]
-        # Parse the document
         docs_transformed = input_data[0]
+
         if self.parse_html:
             docs_transformed = Html2TextTransformer().transform_documents(input_data[0])
             docs_transformed = docs_transformed[0]
@@ -77,7 +75,6 @@ def execute(self, state: dict) -> dict:
         else:
             docs_transformed = docs_transformed[0]

-        # Adapt the chunk size, leaving room for the reply, the prompt and the schema
         chunk_size = self.node_config.get("chunk_size", 4096)
         chunk_size = min(chunk_size - 500, int(chunk_size * 0.9))
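
The two kept lines implement the reserve the deleted comment described: the effective chunk is the smaller of a flat 500-token cut and a 10% margin, so the flat cut dominates below 5000 tokens and the percentage above. A worked example:

for chunk_size in (1024, 4096, 8192):
    effective = min(chunk_size - 500, int(chunk_size * 0.9))
    print(chunk_size, "->", effective)
# 1024 -> 524, 4096 -> 3596, 8192 -> 7372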

scrapegraphai/nodes/rag_node.py

Lines changed: 1 addition & 11 deletions

@@ -80,10 +80,8 @@ def execute(self, state: dict) -> dict:

         self.logger.info(f"--- Executing {self.node_name} Node ---")

-        # Interpret input keys based on the provided input expression
         input_keys = self.get_input_keys(state)

-        # Fetching data from the state based on the input keys
         input_data = [state[key] for key in input_keys]

         user_prompt = input_data[0]
@@ -102,7 +100,6 @@ def execute(self, state: dict) -> dict:

         self.logger.info("--- (updated chunks metadata) ---")

-        # check if embedder_model is provided, if not use llm_model
         if self.embedder_model is not None:
             embeddings = self.embedder_model
         elif 'embeddings' in self.node_config:
@@ -144,23 +141,17 @@ def execute(self, state: dict) -> dict:
             pipeline_compressor = DocumentCompressorPipeline(
                 transformers=[redundant_filter, relevant_filter]
             )
-            # redundant + relevant filter compressor
             compression_retriever = ContextualCompressionRetriever(
                 base_compressor=pipeline_compressor, base_retriever=retriever
             )

-            # relevant filter compressor only
-            # compression_retriever = ContextualCompressionRetriever(
-            #     base_compressor=relevant_filter, base_retriever=retriever
-            # )
-
             compressed_docs = compression_retriever.invoke(user_prompt)

         self.logger.info("--- (tokens compressed and vector stored) ---")

         state.update({self.output[0]: compressed_docs})
         return state
-
+

     def _create_default_embedder(self, llm_config=None) -> object:
         """
@@ -223,7 +214,6 @@ def _create_embedder(self, embedder_config: dict) -> object:
         embedder_params = {**embedder_config}
         if "model_instance" in embedder_config:
             return embedder_params["model_instance"]
-        # Instantiate the embedding model based on the model name
         if "openai" in embedder_params["model"]:
             return OpenAIEmbeddings(api_key=embedder_params["api_key"])
         if "azure" in embedder_params["model"]:

scrapegraphai/nodes/robots_node.py

Lines changed: 0 additions & 2 deletions

@@ -75,10 +75,8 @@ def execute(self, state: dict) -> dict:

         self.logger.info(f"--- Executing {self.node_name} Node ---")

-        # Interpret input keys based on the provided input expression
         input_keys = self.get_input_keys(state)

-        # Fetching data from the state based on the input keys
         input_data = [state[key] for key in input_keys]

         source = input_data[0]

scrapegraphai/nodes/search_internet_node.py

Lines changed: 1 addition & 6 deletions

@@ -67,7 +67,6 @@ def execute(self, state: dict) -> dict:

         input_keys = self.get_input_keys(state)

-        # Fetching data from the state based on the input keys
         input_data = [state[key] for key in input_keys]

         user_prompt = input_data[0]
@@ -79,10 +78,8 @@ def execute(self, state: dict) -> dict:
             input_variables=["user_prompt"],
         )

-        # Execute the chain to get the search query
         search_answer = search_prompt | self.llm_model | output_parser
-
-        # Ollama: Use no json format when creating the search query
+
         if isinstance(self.llm_model, ChatOllama) and self.llm_model.format == 'json':
             self.llm_model.format = None
         search_query = search_answer.invoke({"user_prompt": user_prompt})[0]
@@ -96,9 +93,7 @@ def execute(self, state: dict) -> dict:
                                   search_engine=self.search_engine)

         if len(answer) == 0:
-            # raise an exception if no answer is found
             raise ValueError("Zero results found for the search query.")

-        # Update the state with the generated answer
         state.update({self.output[0]: answer})
         return state
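
The kept conditional clears Ollama's JSON mode before generating a free-form search query, since JSON-constrained output suits structured answers but not query text. A minimal sketch, assuming langchain_community's ChatOllama (constructing the client does not contact a server; the model name is illustrative):

from langchain_community.chat_models import ChatOllama

llm = ChatOllama(model="llama3", format="json")
if isinstance(llm, ChatOllama) and llm.format == "json":
    llm.format = None  # free-form text suits query generation
print(llm.format)  # -> None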
