 from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
 from langchain_google_genai import GoogleGenerativeAIEmbeddings
 from ..helpers import models_tokens
-from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic
+from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic, DeepSeek
 from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
 
+
 class AbstractGraph(ABC):
     """
     Scaffolding class for creating a graph representation and executing it.
 
     prompt (str): The prompt for the graph.
     source (str): The source of the graph.
     config (dict): Configuration parameters for the graph.
+    schema (str): The schema for the graph output.
     llm_model: An instance of a language model client, configured for generating answers.
     embedder_model: An instance of an embedding model client,
         configured for generating embeddings.
@@ -28,6 +30,7 @@ class AbstractGraph(ABC):
     prompt (str): The prompt for the graph.
     config (dict): Configuration parameters for the graph.
     source (str, optional): The source of the graph.
+    schema (str, optional): The schema for the graph output.
 
     Example:
     >>> class MyGraph(AbstractGraph):
@@ -39,34 +42,37 @@ class AbstractGraph(ABC):
     >>> result = my_graph.run()
     """
 
-    def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
+    def __init__(self, prompt: str, config: dict, source: Optional[str] = None, schema: Optional[str] = None):
 
         self.prompt = prompt
         self.source = source
         self.config = config
+        self.schema = schema
         self.llm_model = self._create_llm(config["llm"], chat=True)
         self.embedder_model = self._create_default_embedder(llm_config=config["llm"]
             ) if "embeddings" not in config else self._create_embedder(
             config["embeddings"])
+        self.verbose = False if config is None else config.get(
+            "verbose", False)
+        self.headless = True if config is None else config.get(
+            "headless", True)
+        self.loader_kwargs = config.get("loader_kwargs", {})
 
         # Create the graph
         self.graph = self._create_graph()
         self.final_state = None
         self.execution_info = None
 
         # Set common configuration parameters
-
-        self.verbose = False if config is None else config.get(
-            "verbose", False)
-        self.headless = True if config is None else config.get(
-            "headless", True)
-        self.loader_kwargs = config.get("loader_kwargs", {})
 
-        common_params = {"headless": self.headless,
-
-                         "loader_kwargs": self.loader_kwargs,
-                         "llm_model": self.llm_model,
-                         "embedder_model": self.embedder_model}
+        common_params = {
+            "headless": self.headless,
+            "verbose": self.verbose,
+            "loader_kwargs": self.loader_kwargs,
+            "llm_model": self.llm_model,
+            "embedder_model": self.embedder_model
+        }
+
         self.set_common_params(common_params, overwrite=False)
 
     def set_common_params(self, params: dict, overwrite=False):
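A minimal sketch of how the extended constructor might be called once this lands. `MyGraph` is the hypothetical subclass from the docstring example above; every config value below is a placeholder for illustration, not part of this diff:

```python
# Hypothetical usage of the new signature; all values are placeholders.
config = {
    "llm": {"model": "ollama/llama3"},  # resolved by _create_llm below
    "verbose": True,                    # read in __init__, defaults to False
    "headless": False,                  # read in __init__, defaults to True
    "loader_kwargs": {},                # forwarded to nodes via common_params
}

my_graph = MyGraph(
    prompt="Describe the page",
    source="https://example.com",
    config=config,
    schema='{"title": "str"}',          # new optional schema argument
)
result = my_graph.run()
```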
@@ -79,7 +85,7 @@ def set_common_params(self, params: dict, overwrite=False):
 
         for node in self.graph.nodes:
             node.update_config(params, overwrite)
-
+
     def _set_model_token(self, llm):
 
         if 'Azure' in str(type(llm)):
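`set_common_params` is invoked with `overwrite=False`, so a node's own explicit configuration wins over the shared defaults. A sketch of the merge semantics this presumably relies on; the real `update_config` lives in the node classes and is not part of this diff, so this is only an assumed shape:

```python
# Assumed behavior of node.update_config(params, overwrite):
# shared params only fill gaps unless overwrite=True.
def update_config(self, params: dict, overwrite: bool = False):
    for key, value in params.items():
        if overwrite or getattr(self, key, None) is None:
            setattr(self, key, value)  # e.g. llm_model, verbose, headless
```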
@@ -157,7 +163,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
                 raise KeyError("Model not supported") from exc
             return Anthropic(llm_params)
         elif "ollama" in llm_params["model"]:
-            llm_params["model"] = llm_params["model"].split("/")[-1]
+            llm_params["model"] = llm_params["model"].split("ollama/")[-1]
 
             # allow user to set model_tokens in config
             try:
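The split target changes from `"/"` to the literal `"ollama/"` prefix, so model identifiers that themselves contain a slash survive the strip; the same prefix-aware split is applied in `_create_embedder` further down. An illustrative comparison (the model names are examples only):

```python
# Old vs. new behavior of the prefix strip (illustrative names).
model = "ollama/llama3"
model.split("/")[-1]        # 'llama3'          -> both forms agree here
model.split("ollama/")[-1]  # 'llama3'

model = "ollama/library/llama3"
model.split("/")[-1]        # 'llama3'          -> silently drops 'library/'
model.split("ollama/")[-1]  # 'library/llama3'  -> keeps the full identifier
```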
@@ -231,6 +237,8 @@ def _create_default_embedder(self, llm_config=None) -> object:
                 model="models/embedding-001")
         if isinstance(self.llm_model, OpenAI):
             return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
+        elif isinstance(self.llm_model, DeepSeek):
+            return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
         elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
             return self.llm_model
         elif isinstance(self.llm_model, AzureOpenAI):
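The new `DeepSeek` branch mirrors the `OpenAI` one: the wrapper apparently stores its key as `openai_api_key` (DeepSeek's API being OpenAI-compatible), and the default embedder falls back to `OpenAIEmbeddings` with that key. Whether a DeepSeek key is accepted by OpenAI's embeddings endpoint depends on the deployment, so passing an explicit `embeddings` config sidesteps the question entirely. Keys and model names below are placeholders:

```python
# Placeholder config: use DeepSeek for generation but pin the embedder
# explicitly instead of relying on the default fallback above.
config = {
    "llm": {"model": "deepseek-chat", "api_key": "sk-..."},
    "embeddings": {"model": "ollama/nomic-embed-text"},  # ollama branch below
}
```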
@@ -271,7 +279,7 @@ def _create_embedder(self, embedder_config: dict) -> object:
         elif "azure" in embedder_config["model"]:
             return AzureOpenAIEmbeddings()
         elif "ollama" in embedder_config["model"]:
-            embedder_config["model"] = embedder_config["model"].split("/")[-1]
+            embedder_config["model"] = embedder_config["model"].split("ollama/")[-1]
             try:
                 models_tokens["ollama"][embedder_config["model"]]
             except KeyError as exc:
@@ -297,6 +305,10 @@ def _create_embedder(self, embedder_config: dict) -> object:
             except KeyError as exc:
                 raise KeyError("Model not supported") from exc
             return BedrockEmbeddings(client=client, model_id=embedder_config["model"])
+        else:
+            raise ValueError(
+                "Model provided by the configuration not supported")
+
     def get_state(self, key=None) -> dict:
         """
         Get the final state of the graph.
@@ -334,4 +346,4 @@ def run(self) -> str:
         """
         Abstract method to execute the graph and return the result.
         """
-        pass
+        pass
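With the new `else` fallthrough in `_create_embedder`, an unrecognized embedder spec now raises immediately instead of falling off the end of the `elif` chain and returning `None`. A sketch of the failure mode, reusing the hypothetical `MyGraph` subclass; the provider name is deliberately bogus:

```python
# The unmatched provider now fails fast with a clear message.
try:
    graph = MyGraph(
        prompt="Describe the page",
        source="https://example.com",
        config={
            "llm": {"model": "ollama/llama3"},
            "embeddings": {"model": "acme/embedder"},  # no matching branch
        },
    )
except ValueError as err:
    print(err)  # Model provided by the configuration not supported
```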