55from .base_graph import BaseGraph
66from ..nodes import (
77 SearchInternetNode ,
8- FetchNode ,
9- ParseNode ,
10- RAGNode ,
11- GenerateAnswerNode
8+ GraphIteratorNode ,
9+ MergeAnswersNode
1210)
1311from .abstract_graph import AbstractGraph
12+ from .smart_scraper_graph import SmartScraperGraph
1413
1514
1615class SearchGraph (AbstractGraph ):
@@ -38,6 +37,11 @@ class SearchGraph(AbstractGraph):
3837 >>> result = search_graph.run()
3938 """
4039
40+ def __init__ (self , prompt : str , config : dict ):
41+
42+ self .max_results = config .get ("max_results" , 3 )
43+ super ().__init__ (prompt , config )
44+
4145 def _create_graph (self ) -> BaseGraph :
4246 """
4347 Creates the graph of nodes representing the workflow for web scraping and searching.
@@ -46,53 +50,53 @@ def _create_graph(self) -> BaseGraph:
4650 BaseGraph: A graph instance representing the web scraping and searching workflow.
4751 """
4852
53+ # ************************************************
54+ # Create a SmartScraperGraph instance
55+ # ************************************************
56+
57+ smart_scraper_instance = SmartScraperGraph (
58+ prompt = "" ,
59+ source = "" ,
60+ config = self .config
61+ )
62+
63+ # ************************************************
64+ # Define the graph nodes
65+ # ************************************************
66+
4967 search_internet_node = SearchInternetNode (
5068 input = "user_prompt" ,
51- output = ["url" ],
52- node_config = {
53- "llm_model" : self .llm_model
54- }
55- )
56- fetch_node = FetchNode (
57- input = "url | local_dir" ,
58- output = ["doc" ]
59- )
60- parse_node = ParseNode (
61- input = "doc" ,
62- output = ["parsed_doc" ],
69+ output = ["urls" ],
6370 node_config = {
64- "chunk_size" : self .model_token
71+ "llm_model" : self .llm_model ,
72+ "max_results" : self .max_results
6573 }
6674 )
67- rag_node = RAGNode (
68- input = "user_prompt & (parsed_doc | doc) " ,
69- output = ["relevant_chunks " ],
75+ graph_iterator_node = GraphIteratorNode (
76+ input = "user_prompt & urls " ,
77+ output = ["results " ],
7078 node_config = {
71- "llm_model" : self .llm_model ,
72- "embedder_model" : self .embedder_model
79+ "graph_instance" : smart_scraper_instance ,
7380 }
7481 )
75- generate_answer_node = GenerateAnswerNode (
76- input = "user_prompt & (relevant_chunks | parsed_doc | doc)" ,
82+
83+ merge_answers_node = MergeAnswersNode (
84+ input = "user_prompt & results" ,
7785 output = ["answer" ],
7886 node_config = {
79- "llm_model" : self .llm_model
87+ "llm_model" : self .llm_model ,
8088 }
8189 )
8290
8391 return BaseGraph (
8492 nodes = [
8593 search_internet_node ,
86- fetch_node ,
87- parse_node ,
88- rag_node ,
89- generate_answer_node ,
94+ graph_iterator_node ,
95+ merge_answers_node
9096 ],
9197 edges = [
92- (search_internet_node , fetch_node ),
93- (fetch_node , parse_node ),
94- (parse_node , rag_node ),
95- (rag_node , generate_answer_node )
98+ (search_internet_node , graph_iterator_node ),
99+ (graph_iterator_node , merge_answers_node )
96100 ],
97101 entry_point = search_internet_node
98102 )
0 commit comments