from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings

from ..helpers import models_tokens
-from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI
+from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Claude


class AbstractGraph(ABC):
@@ -22,7 +22,8 @@ class AbstractGraph(ABC):
        source (str): The source of the graph.
        config (dict): Configuration parameters for the graph.
        llm_model: An instance of a language model client, configured for generating answers.
-        embedder_model: An instance of an embedding model client, configured for generating embeddings.
+        embedder_model: An instance of an embedding model client,
+        configured for generating embeddings.
        verbose (bool): A flag indicating whether to show print statements during execution.
        headless (bool): A flag indicating whether to run the graph in headless mode.

@@ -47,8 +48,8 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
        self.source = source
        self.config = config
        self.llm_model = self._create_llm(config["llm"], chat=True)
-        self.embedder_model = self._create_default_embedder(
-            ) if "embeddings" not in config else self._create_embedder(
+        self.embedder_model = self._create_default_embedder(
+            ) if "embeddings" not in config else self._create_embedder(
            config["embeddings"])

        # Set common configuration parameters
@@ -61,23 +62,21 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
        self.final_state = None
        self.execution_info = None

-
    def _set_model_token(self, llm):

        if 'Azure' in str(type(llm)):
            try:
                self.model_token = models_tokens["azure"][llm.model_name]
            except KeyError:
                raise KeyError("Model not supported")
-
+
        elif 'HuggingFaceEndpoint' in str(type(llm)):
            if 'mistral' in llm.repo_id:
                try:
                    self.model_token = models_tokens['mistral'][llm.repo_id]
                except KeyError:
                    raise KeyError("Model not supported")

-
    def _create_llm(self, llm_config: dict, chat=False) -> object:
        """
        Create a large language model instance based on the configuration provided.
@@ -103,31 +102,36 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
        if chat:
            self._set_model_token(llm_params['model_instance'])
            return llm_params['model_instance']
-
+
        # Instantiate the language model based on the model name
        if "gpt-" in llm_params["model"]:
            try:
                self.model_token = models_tokens["openai"][llm_params["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
            return OpenAI(llm_params)

        elif "azure" in llm_params["model"]:
            # take the model after the last dash
            llm_params["model"] = llm_params["model"].split("/")[-1]
            try:
                self.model_token = models_tokens["azure"][llm_params["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
            return AzureOpenAI(llm_params)

        elif "gemini" in llm_params["model"]:
            try:
                self.model_token = models_tokens["gemini"][llm_params["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
            return Gemini(llm_params)
-
+        elif "claude" in llm_params["model"]:
+            try:
+                self.model_token = models_tokens["claude"][llm_params["model"]]
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
+            return Claude(llm_params)
        elif "ollama" in llm_params["model"]:
            llm_params["model"] = llm_params["model"].split("/")[-1]

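For context, a minimal usage sketch of the new branch above (not part of the diff). The model key and the api_key field are assumptions; valid names are whatever models_tokens["claude"] actually contains:

    # Hypothetical config: any "model" value containing "claude" now routes to Claude(llm_params).
    graph_config = {
        "llm": {
            "model": "claude-3-haiku-20240307",   # assumed key in models_tokens["claude"]
            "api_key": "YOUR_ANTHROPIC_API_KEY",  # assumed parameter name
        },
    }
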
@@ -138,8 +142,8 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
                elif llm_params["model"] in models_tokens["ollama"]:
                    try:
                        self.model_token = models_tokens["ollama"][llm_params["model"]]
-                    except KeyError:
-                        raise KeyError("Model not supported")
+                    except KeyError as exc:
+                        raise KeyError("Model not supported") from exc
                else:
                    self.model_token = 8192
            except AttributeError:
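For context, the lookup above lets callers bypass the models_tokens table entirely; a sketch (not part of the diff), assuming "llama3" is a key in models_tokens["ollama"]:

    # "model_tokens" in the config takes precedence over the table lookup;
    # models missing from both fall back to the 8192 default.
    graph_config = {
        "llm": {
            "model": "ollama/llama3",  # "ollama/" prefix is stripped by split("/")[-1]
            "model_tokens": 4096,      # explicit context-size override
        },
    }
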
@@ -149,25 +153,25 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
149153 elif "hugging_face" in llm_params ["model" ]:
150154 try :
151155 self .model_token = models_tokens ["hugging_face" ][llm_params ["model" ]]
152- except KeyError :
153- raise KeyError ("Model not supported" )
156+ except KeyError as exc :
157+ raise KeyError ("Model not supported" ) from exc
154158 return HuggingFace (llm_params )
155159 elif "groq" in llm_params ["model" ]:
156160 llm_params ["model" ] = llm_params ["model" ].split ("/" )[- 1 ]
157161
158162 try :
159163 self .model_token = models_tokens ["groq" ][llm_params ["model" ]]
160- except KeyError :
161- raise KeyError ("Model not supported" )
164+ except KeyError as exc :
165+ raise KeyError ("Model not supported" ) from exc
162166 return Groq (llm_params )
163167 elif "bedrock" in llm_params ["model" ]:
164168 llm_params ["model" ] = llm_params ["model" ].split ("/" )[- 1 ]
165169 model_id = llm_params ["model" ]
166170
167171 try :
168172 self .model_token = models_tokens ["bedrock" ][llm_params ["model" ]]
169- except KeyError :
170- raise KeyError ("Model not supported" )
173+ except KeyError as exc :
174+ raise KeyError ("Model not supported" ) from exc
171175 return Bedrock ({
172176 "model_id" : model_id ,
173177 "model_kwargs" : {
@@ -177,7 +181,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
        else:
            raise ValueError(
                "Model provided by the configuration not supported")
-
+
    def _create_default_embedder(self) -> object:
        """
        Create an embedding model instance based on the chosen llm model.
@@ -208,7 +212,7 @@ def _create_default_embedder(self) -> object:
            return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
        else:
            raise ValueError("Embedding Model missing or not supported")
-
+
    def _create_embedder(self, embedder_config: dict) -> object:
        """
        Create an embedding model instance based on the configuration provided.
@@ -237,27 +241,27 @@ def _create_embedder(self, embedder_config: dict) -> object:
            embedder_config["model"] = embedder_config["model"].split("/")[-1]
            try:
                models_tokens["ollama"][embedder_config["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
            return OllamaEmbeddings(**embedder_config)
-
+
        elif "hugging_face" in embedder_config["model"]:
            try:
                models_tokens["hugging_face"][embedder_config["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
            return HuggingFaceHubEmbeddings(model=embedder_config["model"])
-
+
        elif "bedrock" in embedder_config["model"]:
            embedder_config["model"] = embedder_config["model"].split("/")[-1]
            try:
                models_tokens["bedrock"][embedder_config["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
            return BedrockEmbeddings(client=None, model_id=embedder_config["model"])
        else:
            raise ValueError(
-                "Model provided by the configuration not supported")
+                "Model provided by the configuration not supported")

    def get_state(self, key=None) -> dict:
263267 """""
@@ -281,7 +285,7 @@ def get_execution_info(self):
        Returns:
            dict: The execution information of the graph.
        """
-
+
        return self.execution_info

    @abstractmethod
@@ -297,4 +301,3 @@ def run(self) -> str:
        Abstract method to execute the graph and return the result.
        """
        pass
-
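
For context, _create_embedder only runs when the config carries an "embeddings" block; otherwise __init__ falls back to _create_default_embedder, which infers an embedder from the chosen LLM. A sketch (not part of the diff) pairing the new Claude backend with an Ollama embedder; both model keys are assumptions:

    # Hypothetical end-to-end config; "ollama/" is stripped before the
    # models_tokens["ollama"] lookup in _create_embedder.
    graph_config = {
        "llm": {"model": "claude-3-haiku-20240307"},
        "embeddings": {"model": "ollama/nomic-embed-text"},
    }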