1010from langchain_openai import AzureOpenAIEmbeddings , OpenAIEmbeddings
1111
1212from ..helpers import models_tokens
13- from ..models import AzureOpenAI , Bedrock , Gemini , Groq , HuggingFace , Ollama , OpenAI
13+ from ..models import AzureOpenAI , Bedrock , Gemini , Groq , HuggingFace , Ollama , OpenAI , Anthropic
1414
1515
1616class AbstractGraph (ABC ):
@@ -47,8 +47,8 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
4747 self .source = source
4848 self .config = config
4949 self .llm_model = self ._create_llm (config ["llm" ], chat = True )
50- self .embedder_model = self ._create_default_embedder (
51- ) if "embeddings" not in config else self ._create_embedder (
50+ self .embedder_model = self ._create_default_embedder (
51+ ) if "embeddings" not in config else self ._create_embedder (
5252 config ["embeddings" ])
5353
5454 # Set common configuration parameters
@@ -61,23 +61,21 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
6161 self .final_state = None
6262 self .execution_info = None
6363
64-
6564 def _set_model_token (self , llm ):
6665
6766 if 'Azure' in str (type (llm )):
6867 try :
6968 self .model_token = models_tokens ["azure" ][llm .model_name ]
7069 except KeyError :
7170 raise KeyError ("Model not supported" )
72-
71+
7372 elif 'HuggingFaceEndpoint' in str (type (llm )):
7473 if 'mistral' in llm .repo_id :
7574 try :
7675 self .model_token = models_tokens ['mistral' ][llm .repo_id ]
7776 except KeyError :
7877 raise KeyError ("Model not supported" )
7978
80-
8179 def _create_llm (self , llm_config : dict , chat = False ) -> object :
8280 """
8381 Create a large language model instance based on the configuration provided.
@@ -103,7 +101,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
103101 if chat :
104102 self ._set_model_token (llm_params ['model_instance' ])
105103 return llm_params ['model_instance' ]
106-
104+
107105 # Instantiate the language model based on the model name
108106 if "gpt-" in llm_params ["model" ]:
109107 try :
@@ -174,10 +172,13 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
174172 "temperature" : llm_params ["temperature" ],
175173 }
176174 })
175+ elif "claude-3-" in llm_params ["model" ]:
176+ self .model_token = models_tokens ["claude" ]["claude3" ]
177+ return Anthropic (llm_params )
177178 else :
178179 raise ValueError (
179180 "Model provided by the configuration not supported" )
180-
181+
181182 def _create_default_embedder (self ) -> object :
182183 """
183184 Create an embedding model instance based on the chosen llm model.
@@ -208,7 +209,7 @@ def _create_default_embedder(self) -> object:
208209 return BedrockEmbeddings (client = None , model_id = self .llm_model .model_id )
209210 else :
210211 raise ValueError ("Embedding Model missing or not supported" )
211-
212+
212213 def _create_embedder (self , embedder_config : dict ) -> object :
213214 """
214215 Create an embedding model instance based on the configuration provided.
@@ -225,7 +226,7 @@ def _create_embedder(self, embedder_config: dict) -> object:
225226
226227 if 'model_instance' in embedder_config :
227228 return embedder_config ['model_instance' ]
228-
229+
229230 # Instantiate the embedding model based on the model name
230231 if "openai" in embedder_config ["model" ]:
231232 return OpenAIEmbeddings (api_key = embedder_config ["api_key" ])
@@ -240,14 +241,14 @@ def _create_embedder(self, embedder_config: dict) -> object:
240241 except KeyError :
241242 raise KeyError ("Model not supported" )
242243 return OllamaEmbeddings (** embedder_config )
243-
244+
244245 elif "hugging_face" in embedder_config ["model" ]:
245246 try :
246247 models_tokens ["hugging_face" ][embedder_config ["model" ]]
247248 except KeyError :
248249 raise KeyError ("Model not supported" )
249250 return HuggingFaceHubEmbeddings (model = embedder_config ["model" ])
250-
251+
251252 elif "bedrock" in embedder_config ["model" ]:
252253 embedder_config ["model" ] = embedder_config ["model" ].split ("/" )[- 1 ]
253254 try :
@@ -257,7 +258,7 @@ def _create_embedder(self, embedder_config: dict) -> object:
257258 return BedrockEmbeddings (client = None , model_id = embedder_config ["model" ])
258259 else :
259260 raise ValueError (
260- "Model provided by the configuration not supported" )
261+ "Model provided by the configuration not supported" )
261262
262263 def get_state (self , key = None ) -> dict :
263264 """""
@@ -281,7 +282,7 @@ def get_execution_info(self):
281282 Returns:
282283 dict: The execution information of the graph.
283284 """
284-
285+
285286 return self .execution_info
286287
287288 @abstractmethod
@@ -297,4 +298,3 @@ def run(self) -> str:
297298 Abstract method to execute the graph and return the result.
298299 """
299300 pass
300-
0 commit comments