@@ -47,8 +47,8 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
4747 self .source = source
4848 self .config = config
4949 self .llm_model = self ._create_llm (config ["llm" ], chat = True )
50- self .embedder_model = self ._create_default_embedder (
51- ) if "embeddings" not in config else self ._create_embedder (
50+ self .embedder_model = self ._create_default_embedder (
51+ ) if "embeddings" not in config else self ._create_embedder (
5252 config ["embeddings" ])
5353
5454 # Set common configuration parameters
@@ -61,21 +61,23 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
6161 self .final_state = None
6262 self .execution_info = None
6363
64+
6465 def _set_model_token (self , llm ):
6566
6667 if 'Azure' in str (type (llm )):
6768 try :
6869 self .model_token = models_tokens ["azure" ][llm .model_name ]
6970 except KeyError :
7071 raise KeyError ("Model not supported" )
71-
72+
7273 elif 'HuggingFaceEndpoint' in str (type (llm )):
7374 if 'mistral' in llm .repo_id :
7475 try :
7576 self .model_token = models_tokens ['mistral' ][llm .repo_id ]
7677 except KeyError :
7778 raise KeyError ("Model not supported" )
7879
80+
7981 def _create_llm (self , llm_config : dict , chat = False ) -> object :
8082 """
8183 Create a large language model instance based on the configuration provided.
@@ -101,7 +103,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
101103 if chat :
102104 self ._set_model_token (llm_params ['model_instance' ])
103105 return llm_params ['model_instance' ]
104-
106+
105107 # Instantiate the language model based on the model name
106108 if "gpt-" in llm_params ["model" ]:
107109 try :
@@ -178,7 +180,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
178180 else :
179181 raise ValueError (
180182 "Model provided by the configuration not supported" )
181-
183+
182184 def _create_default_embedder (self ) -> object :
183185 """
184186 Create an embedding model instance based on the chosen llm model.
@@ -209,7 +211,7 @@ def _create_default_embedder(self) -> object:
209211 return BedrockEmbeddings (client = None , model_id = self .llm_model .model_id )
210212 else :
211213 raise ValueError ("Embedding Model missing or not supported" )
212-
214+
213215 def _create_embedder (self , embedder_config : dict ) -> object :
214216 """
215217 Create an embedding model instance based on the configuration provided.
@@ -226,7 +228,7 @@ def _create_embedder(self, embedder_config: dict) -> object:
226228
227229 if 'model_instance' in embedder_config :
228230 return embedder_config ['model_instance' ]
229-
231+
230232 # Instantiate the embedding model based on the model name
231233 if "openai" in embedder_config ["model" ]:
232234 return OpenAIEmbeddings (api_key = embedder_config ["api_key" ])
@@ -241,14 +243,14 @@ def _create_embedder(self, embedder_config: dict) -> object:
241243 except KeyError :
242244 raise KeyError ("Model not supported" )
243245 return OllamaEmbeddings (** embedder_config )
244-
246+
245247 elif "hugging_face" in embedder_config ["model" ]:
246248 try :
247249 models_tokens ["hugging_face" ][embedder_config ["model" ]]
248250 except KeyError :
249251 raise KeyError ("Model not supported" )
250252 return HuggingFaceHubEmbeddings (model = embedder_config ["model" ])
251-
253+
252254 elif "bedrock" in embedder_config ["model" ]:
253255 embedder_config ["model" ] = embedder_config ["model" ].split ("/" )[- 1 ]
254256 try :
@@ -258,7 +260,7 @@ def _create_embedder(self, embedder_config: dict) -> object:
258260 return BedrockEmbeddings (client = None , model_id = embedder_config ["model" ])
259261 else :
260262 raise ValueError (
261- "Model provided by the configuration not supported" )
263+ "Model provided by the configuration not supported" )
262264
263265 def get_state (self , key = None ) -> dict :
264266 """""
@@ -282,7 +284,7 @@ def get_execution_info(self):
282284 Returns:
283285 dict: The execution information of the graph.
284286 """
285-
287+
286288 return self .execution_info
287289
288290 @abstractmethod
@@ -298,3 +300,4 @@ def run(self) -> str:
298300 Abstract method to execute the graph and return the result.
299301 """
300302 pass
303+
0 commit comments