diff --git a/examples/groq/smart_scraper_multi_cond_groq.py b/examples/extras/conditional_usage.py similarity index 74% rename from examples/groq/smart_scraper_multi_cond_groq.py rename to examples/extras/conditional_usage.py index 7e81cfd2..d3152bed 100644 --- a/examples/groq/smart_scraper_multi_cond_groq.py +++ b/examples/extras/conditional_usage.py @@ -5,7 +5,7 @@ import os import json from dotenv import load_dotenv -from scrapegraphai.graphs import SmartScraperMultiCondGraph +from scrapegraphai.graphs import SmartScraperMultiGraph load_dotenv() @@ -13,22 +13,21 @@ # Define the configuration for the graph # ************************************************ -groq_key = os.getenv("GROQ_APIKEY") - graph_config = { "llm": { - "model": "groq/gemma-7b-it", - "api_key": groq_key, - "temperature": 0 + "api_key": os.getenv("OPENAI_API_KEY"), + "model": "openai/gpt-4o", }, - "headless": False + + "verbose": True, + "headless": False, } # ******************************************************* # Create the SmartScraperMultiCondGraph instance and run it # ******************************************************* -multiple_search_graph = SmartScraperMultiCondGraph( +multiple_search_graph = SmartScraperMultiGraph( prompt="Who is Marco Perini?", source=[ "https://perinim.github.io/", diff --git a/scrapegraphai/graphs/__init__.py b/scrapegraphai/graphs/__init__.py index 3415af3e..b5ffcc47 100644 --- a/scrapegraphai/graphs/__init__.py +++ b/scrapegraphai/graphs/__init__.py @@ -26,5 +26,4 @@ from .screenshot_scraper_graph import ScreenshotScraperGraph from .smart_scraper_multi_concat_graph import SmartScraperMultiConcatGraph from .code_generator_graph import CodeGeneratorGraph -from .smart_scraper_multi_cond_graph import SmartScraperMultiCondGraph from .depth_search_graph import DepthSearchGraph diff --git a/scrapegraphai/graphs/markdown_scraper_multi_graph.py b/scrapegraphai/graphs/markdown_scraper_multi_graph.py index 1857f872..b6a13111 100644 --- 
a/scrapegraphai/graphs/markdown_scraper_multi_graph.py +++ b/scrapegraphai/graphs/markdown_scraper_multi_graph.py @@ -41,7 +41,7 @@ class MDScraperMultiGraph(AbstractGraph): >>> result = search_graph.run() """ - def __init__(self, prompt: str, source: List[str], + def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[BaseModel] = None): self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) diff --git a/scrapegraphai/graphs/smart_scraper_graph.py b/scrapegraphai/graphs/smart_scraper_graph.py index 60407624..4a2a0a6a 100644 --- a/scrapegraphai/graphs/smart_scraper_graph.py +++ b/scrapegraphai/graphs/smart_scraper_graph.py @@ -10,7 +10,8 @@ FetchNode, ParseNode, ReasoningNode, - GenerateAnswerNode + GenerateAnswerNode, + ConditionalNode ) class SmartScraperGraph(AbstractGraph): diff --git a/scrapegraphai/graphs/smart_scraper_multi_concat_graph.py b/scrapegraphai/graphs/smart_scraper_multi_concat_graph.py index ce879317..cc1a7003 100644 --- a/scrapegraphai/graphs/smart_scraper_multi_concat_graph.py +++ b/scrapegraphai/graphs/smart_scraper_multi_concat_graph.py @@ -1,5 +1,5 @@ -""" -SmartScraperMultiGraph Module +""" +SmartScraperMultiConcatGraph Module with ConditionalNode """ from copy import deepcopy from typing import List, Optional @@ -9,15 +9,16 @@ from .smart_scraper_graph import SmartScraperGraph from ..nodes import ( GraphIteratorNode, - ConcatAnswersNode + MergeAnswersNode, + ConcatAnswersNode, + ConditionalNode ) from ..utils.copy import safe_deepcopy -class SmartScraperMultiConcatGraph(AbstractGraph): +class SmartScraperMultiConcatGraph(AbstractGraph): """ - SmartScraperMultiGraph is a scraping pipeline that scrapes a + SmartScraperMultiConcatGraph is a scraping pipeline that scrapes a list of URLs and generates answers to a given prompt. - It only requires a user prompt and a list of URLs. Attributes: prompt (str): The user prompt to search the internet. 
@@ -34,24 +35,26 @@ class SmartScraperMultiConcatGraph(AbstractGraph): schema (Optional[BaseModel]): The schema for the graph output. Example: - >>> search_graph = SmartScraperMultiConcatGraph( + >>> search_graph = SmartScraperMultiConcatGraph( ... "What is Chioggia famous for?", ... {"llm": {"model": "openai/gpt-3.5-turbo"}} ... ) >>> result = search_graph.run() """ - - def __init__(self, prompt: str, source: List[str], + + def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[BaseModel] = None): - self.copy_config = safe_deepcopy(config) + self.max_results = config.get("max_results", 3) + self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) super().__init__(prompt, config, source, schema) def _create_graph(self) -> BaseGraph: """ - Creates the graph of nodes representing the workflow for web scraping and searching. + Creates the graph of nodes representing the workflow for web scraping and searching, + including a ConditionalNode to decide between merging or concatenating the results. Returns: BaseGraph: A graph instance representing the web scraping and searching workflow. 
@@ -65,20 +68,49 @@ def _create_graph(self) -> BaseGraph: "scraper_config": self.copy_config, }, schema=self.copy_schema, + node_name="GraphIteratorNode" + ) + + conditional_node = ConditionalNode( + input="results", + output=["results"], + node_name="ConditionalNode", + node_config={ + 'key_name': 'results', + 'condition': 'len(results) > 2' + } + ) + + merge_answers_node = MergeAnswersNode( + input="user_prompt & results", + output=["answer"], + node_config={ + "llm_model": self.llm_model, + "schema": self.copy_schema + }, + node_name="MergeAnswersNode" ) - concat_answers_node = ConcatAnswersNode( + concat_node = ConcatAnswersNode( input="results", - output=["answer"] + output=["answer"], + node_config={}, + node_name="ConcatNode" ) return BaseGraph( nodes=[ graph_iterator_node, - concat_answers_node, + conditional_node, + merge_answers_node, + concat_node, ], edges=[ - (graph_iterator_node, concat_answers_node), + (graph_iterator_node, conditional_node), + # True node (len(results) > 2) + (conditional_node, merge_answers_node), + # False node (len(results) <= 2) + (conditional_node, concat_node) ], entry_point=graph_iterator_node, graph_name=self.__class__.__name__ diff --git a/scrapegraphai/graphs/smart_scraper_multi_cond_graph.py b/scrapegraphai/graphs/smart_scraper_multi_cond_graph.py deleted file mode 100644 index 278e3905..00000000 --- a/scrapegraphai/graphs/smart_scraper_multi_cond_graph.py +++ /dev/null @@ -1,130 +0,0 @@ -""" -SmartScraperMultiCondGraph Module with ConditionalNode -""" -from copy import deepcopy -from typing import List, Optional -from pydantic import BaseModel -from .base_graph import BaseGraph -from .abstract_graph import AbstractGraph -from .smart_scraper_graph import SmartScraperGraph -from ..nodes import ( - GraphIteratorNode, - MergeAnswersNode, - ConcatAnswersNode, - ConditionalNode -) -from ..utils.copy import safe_deepcopy - -class SmartScraperMultiCondGraph(AbstractGraph): - """ - SmartScraperMultiConditionalGraph is a scraping 
pipeline that scrapes a - list of URLs and generates answers to a given prompt. - - Attributes: - prompt (str): The user prompt to search the internet. - llm_model (dict): The configuration for the language model. - embedder_model (dict): The configuration for the embedder model. - headless (bool): A flag to run the browser in headless mode. - verbose (bool): A flag to display the execution information. - model_token (int): The token limit for the language model. - - Args: - prompt (str): The user prompt to search the internet. - source (List[str]): The source of the graph. - config (dict): Configuration parameters for the graph. - schema (Optional[BaseModel]): The schema for the graph output. - - Example: - >>> search_graph = MultipleSearchGraph( - ... "What is Chioggia famous for?", - ... {"llm": {"model": "openai/gpt-3.5-turbo"}} - ... ) - >>> result = search_graph.run() - """ - - def __init__(self, prompt: str, source: List[str], - config: dict, schema: Optional[BaseModel] = None): - - self.max_results = config.get("max_results", 3) - self.copy_config = safe_deepcopy(config) - self.copy_schema = deepcopy(schema) - - super().__init__(prompt, config, source, schema) - - def _create_graph(self) -> BaseGraph: - """ - Creates the graph of nodes representing the workflow for web scraping and searching, - including a ConditionalNode to decide between merging or concatenating the results. - - Returns: - BaseGraph: A graph instance representing the web scraping and searching workflow. 
- """ - - # Node that iterates over the URLs and collects results - graph_iterator_node = GraphIteratorNode( - input="user_prompt & urls", - output=["results"], - node_config={ - "graph_instance": SmartScraperGraph, - "scraper_config": self.copy_config, - }, - schema=self.copy_schema, - node_name="GraphIteratorNode" - ) - - # ConditionalNode to check if len(results) > 2 - conditional_node = ConditionalNode( - input="results", - output=["results"], - node_name="ConditionalNode", - node_config={ - 'key_name': 'results', - 'condition': 'len(results) > 2' - } - ) - - merge_answers_node = MergeAnswersNode( - input="user_prompt & results", - output=["answer"], - node_config={ - "llm_model": self.llm_model, - "schema": self.copy_schema - }, - node_name="MergeAnswersNode" - ) - - concat_node = ConcatAnswersNode( - input="results", - output=["answer"], - node_config={}, - node_name="ConcatNode" - ) - - # Build the graph - return BaseGraph( - nodes=[ - graph_iterator_node, - conditional_node, - merge_answers_node, - concat_node, - ], - edges=[ - (graph_iterator_node, conditional_node), - (conditional_node, merge_answers_node), # True node (len(results) > 2) - (conditional_node, concat_node), # False node (len(results) <= 2) - ], - entry_point=graph_iterator_node, - graph_name=self.__class__.__name__ - ) - - def run(self) -> str: - """ - Executes the web scraping and searching process. - - Returns: - str: The answer to the prompt. - """ - inputs = {"user_prompt": self.prompt, "urls": self.source} - self.final_state, self.execution_info = self.graph.execute(inputs) - - return self.final_state.get("answer", "No answer found.") diff --git a/scrapegraphai/nodes/conditional_node.py b/scrapegraphai/nodes/conditional_node.py index 238d2919..305844a5 100644 --- a/scrapegraphai/nodes/conditional_node.py +++ b/scrapegraphai/nodes/conditional_node.py @@ -38,17 +38,15 @@ def __init__(self, Initializes an empty ConditionalNode. 
""" super().__init__(node_name, "conditional_node", input, output, 2, node_config) - + try: self.key_name = self.node_config["key_name"] except: raise NotImplementedError("You need to provide key_name inside the node config") - + self.true_node_name = None self.false_node_name = None - self.condition = self.node_config.get("condition", None) - self.eval_instance = EvalWithCompoundTypes() self.eval_instance.functions = {'len': len} @@ -65,21 +63,18 @@ def execute(self, state: dict) -> dict: if self.true_node_name is None or self.false_node_name is None: raise ValueError("ConditionalNode's next nodes are not set properly.") - - # Evaluate the condition + if self.condition: condition_result = self._evaluate_condition(state, self.condition) else: - # Default behavior: check existence and non-emptiness of key_name value = state.get(self.key_name) condition_result = value is not None and value != '' - # Return the appropriate next node name if condition_result: return self.true_node_name else: return self.false_node_name - + def _evaluate_condition(self, state: dict, condition: str) -> bool: """ Parses and evaluates the condition expression against the state. @@ -104,4 +99,4 @@ def _evaluate_condition(self, state: dict, condition: str) -> bool: ) return bool(result) except Exception as e: - raise ValueError(f"Error evaluating condition '{condition}' in {self.node_name}: {e}") \ No newline at end of file + raise ValueError(f"Error evaluating condition '{condition}' in {self.node_name}: {e}")