@@ -10,7 +10,7 @@
 from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
 
 from ..helpers import models_tokens
-from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI
+from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Claude
 
 
 class AbstractGraph(ABC):
@@ -22,7 +22,8 @@ class AbstractGraph(ABC):
         source (str): The source of the graph.
         config (dict): Configuration parameters for the graph.
         llm_model: An instance of a language model client, configured for generating answers.
-        embedder_model: An instance of an embedding model client, configured for generating embeddings.
+        embedder_model: An instance of an embedding model client,
+            configured for generating embeddings.
         verbose (bool): A flag indicating whether to show print statements during execution.
         headless (bool): A flag indicating whether to run the graph in headless mode.
 
@@ -47,8 +48,8 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
         self.source = source
         self.config = config
         self.llm_model = self._create_llm(config["llm"], chat=True)
-        self.embedder_model = self._create_default_embedder(
-        ) if "embeddings" not in config else self._create_embedder(
+        self.embedder_model = self._create_default_embedder(
+        ) if "embeddings" not in config else self._create_embedder(
             config["embeddings"])
 
         # Set common configuration parameters
@@ -61,15 +62,13 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
         self.final_state = None
         self.execution_info = None
 
-
     def _set_model_token(self, llm):
 
         if 'Azure' in str(type(llm)):
             try:
                 self.model_token = models_tokens["azure"][llm.model_name]
-            except KeyError:
-                raise KeyError("Model not supported")
-
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
 
     def _create_llm(self, llm_config: dict, chat=False) -> object:
         """
@@ -96,31 +95,36 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
             if chat:
                 self._set_model_token(llm_params['model_instance'])
             return llm_params['model_instance']
-
+
         # Instantiate the language model based on the model name
         if "gpt-" in llm_params["model"]:
             try:
                 self.model_token = models_tokens["openai"][llm_params["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
             return OpenAI(llm_params)
 
         elif "azure" in llm_params["model"]:
             # take the model after the last dash
             llm_params["model"] = llm_params["model"].split("/")[-1]
             try:
                 self.model_token = models_tokens["azure"][llm_params["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
             return AzureOpenAI(llm_params)
 
         elif "gemini" in llm_params["model"]:
             try:
                 self.model_token = models_tokens["gemini"][llm_params["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
             return Gemini(llm_params)
-
+        elif "claude" in llm_params["model"]:
+            try:
+                self.model_token = models_tokens["claude"][llm_params["model"]]
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
+            return Claude(llm_params)
         elif "ollama" in llm_params["model"]:
             llm_params["model"] = llm_params["model"].split("/")[-1]
 
@@ -131,8 +135,8 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
                 elif llm_params["model"] in models_tokens["ollama"]:
                     try:
                         self.model_token = models_tokens["ollama"][llm_params["model"]]
-                    except KeyError:
-                        raise KeyError("Model not supported")
+                    except KeyError as exc:
+                        raise KeyError("Model not supported") from exc
                 else:
                     self.model_token = 8192
             except AttributeError:
@@ -142,25 +146,25 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
         elif "hugging_face" in llm_params["model"]:
             try:
                 self.model_token = models_tokens["hugging_face"][llm_params["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
             return HuggingFace(llm_params)
         elif "groq" in llm_params["model"]:
             llm_params["model"] = llm_params["model"].split("/")[-1]
 
             try:
                 self.model_token = models_tokens["groq"][llm_params["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
             return Groq(llm_params)
         elif "bedrock" in llm_params["model"]:
             llm_params["model"] = llm_params["model"].split("/")[-1]
             model_id = llm_params["model"]
 
             try:
                 self.model_token = models_tokens["bedrock"][llm_params["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
             return Bedrock({
                 "model_id": model_id,
                 "model_kwargs": {
@@ -170,7 +174,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
         else:
             raise ValueError(
                 "Model provided by the configuration not supported")
-
+
     def _create_default_embedder(self) -> object:
         """
         Create an embedding model instance based on the chosen llm model.
@@ -202,7 +206,7 @@ def _create_default_embedder(self) -> object:
             return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
         else:
             raise ValueError("Embedding Model missing or not supported")
-
+
     def _create_embedder(self, embedder_config: dict) -> object:
         """
         Create an embedding model instance based on the configuration provided.
@@ -216,7 +220,7 @@ def _create_embedder(self, embedder_config: dict) -> object:
         Raises:
             KeyError: If the model is not supported.
         """
-
+
         # Instantiate the embedding model based on the model name
         if "openai" in embedder_config["model"]:
             return OpenAIEmbeddings(api_key=embedder_config["api_key"])
@@ -228,27 +232,27 @@ def _create_embedder(self, embedder_config: dict) -> object:
             embedder_config["model"] = embedder_config["model"].split("/")[-1]
             try:
                 models_tokens["ollama"][embedder_config["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
             return OllamaEmbeddings(**embedder_config)
-
+
         elif "hugging_face" in embedder_config["model"]:
             try:
                 models_tokens["hugging_face"][embedder_config["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
             return HuggingFaceHubEmbeddings(model=embedder_config["model"])
-
+
         elif "bedrock" in embedder_config["model"]:
             embedder_config["model"] = embedder_config["model"].split("/")[-1]
             try:
                 models_tokens["bedrock"][embedder_config["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
             return BedrockEmbeddings(client=None, model_id=embedder_config["model"])
         else:
             raise ValueError(
-                "Model provided by the configuration not supported")
+                "Model provided by the configuration not supported")
 
     def get_state(self, key=None) -> dict:
         """
@@ -272,7 +276,7 @@ def get_execution_info(self):
         Returns:
             dict: The execution information of the graph.
         """
-
+
         return self.execution_info
 
     @abstractmethod
@@ -288,4 +292,3 @@ def run(self) -> str:
         Abstract method to execute the graph and return the result.
         """
         pass
-
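
Taken together, the diff adds Anthropic Claude to the same provider-dispatch pattern already used for OpenAI, Azure, Gemini, and the other backends, and converts every bare `raise KeyError(...)` into a chained `raise ... from exc`. Below is a minimal usage sketch of the new branch; the model key `claude-3-opus-20240229` is an assumption for illustration, not taken from this diff, and any key present in `models_tokens["claude"]` would work.

```python
# Hypothetical sketch: a config that routes through the new "claude"
# branch of _create_llm. The model key is an assumed entry in
# models_tokens["claude"]; any listed Claude model would behave the same.
graph_config = {
    "llm": {
        "model": "claude-3-opus-20240229",
        "api_key": "YOUR_ANTHROPIC_API_KEY",
    },
}

# Any concrete AbstractGraph subclass built with this config dispatches on
# the substring "claude", looks up the token limit in models_tokens["claude"],
# and returns Claude(llm_params). An unknown key now fails with a chained
# exception, so the original lookup error survives as __cause__:
#
#   KeyError: 'Model not supported'
#   (raised from the underlying KeyError for the unrecognized model name)
```

One note on the `from exc` change: explicit chaining preserves the original `KeyError` as `__cause__`, so tracebacks show both the unsupported model name and the library's friendlier message, rather than the implicit "during handling of the above exception, another exception occurred" context.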