1111from langchain_community .embeddings import HuggingFaceHubEmbeddings , OllamaEmbeddings
1212from langchain_google_genai import GoogleGenerativeAIEmbeddings
1313from langchain_google_genai .embeddings import GoogleGenerativeAIEmbeddings
14+ from langchain_fireworks import FireworksEmbeddings
1415from langchain_openai import AzureOpenAIEmbeddings , OpenAIEmbeddings
1516
1617from ..helpers import models_tokens
2324 HuggingFace ,
2425 Ollama ,
2526 OpenAI ,
26- OneApi
27+ OneApi ,
28+ Fireworks
2729)
2830from ..models .ernie import Ernie
2931from ..utils .logging import set_verbosity_debug , set_verbosity_warning , set_verbosity_info
@@ -102,7 +104,7 @@ def __init__(self, prompt: str, config: dict,
102104 "embedder_model" : self .embedder_model ,
103105 "cache_path" : self .cache_path ,
104106 }
105-
107+
106108 self .set_common_params (common_params , overwrite = True )
107109
108110 # set burr config
@@ -125,7 +127,7 @@ def set_common_params(self, params: dict, overwrite=False):
125127
126128 for node in self .graph .nodes :
127129 node .update_config (params , overwrite )
128-
130+
129131 def _create_llm (self , llm_config : dict , chat = False ) -> object :
130132 """
131133 Create a large language model instance based on the configuration provided.
@@ -160,8 +162,15 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
160162 try :
161163 self .model_token = models_tokens ["oneapi" ][llm_params ["model" ]]
162164 except KeyError as exc :
163- raise KeyError ("Model Model not supported" ) from exc
165+ raise KeyError ("Model not supported" ) from exc
164166 return OneApi (llm_params )
167+ elif "fireworks" in llm_params ["model" ]:
168+ try :
169+ self .model_token = models_tokens ["fireworks" ][llm_params ["model" ].split ("/" )[- 1 ]]
170+ llm_params ["model" ] = "/" .join (llm_params ["model" ].split ("/" )[1 :])
171+ except KeyError as exc :
172+ raise KeyError ("Model not supported" ) from exc
173+ return Fireworks (llm_params )
165174 elif "azure" in llm_params ["model" ]:
166175 # take the model after the last slash
167176 llm_params ["model" ] = llm_params ["model" ].split ("/" )[- 1 ]
@@ -172,12 +181,14 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
172181 return AzureOpenAI (llm_params )
173182
174183 elif "gemini" in llm_params ["model" ]:
184+ llm_params ["model" ] = llm_params ["model" ].split ("/" )[- 1 ]
175185 try :
176186 self .model_token = models_tokens ["gemini" ][llm_params ["model" ]]
177187 except KeyError as exc :
178188 raise KeyError ("Model not supported" ) from exc
179189 return Gemini (llm_params )
180190 elif llm_params ["model" ].startswith ("claude" ):
191+ llm_params ["model" ] = llm_params ["model" ].split ("/" )[- 1 ]
181192 try :
182193 self .model_token = models_tokens ["claude" ][llm_params ["model" ]]
183194 except KeyError as exc :
@@ -203,6 +214,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
203214
204215 return Ollama (llm_params )
205216 elif "hugging_face" in llm_params ["model" ]:
217+ llm_params ["model" ] = llm_params ["model" ].split ("/" )[- 1 ]
206218 try :
207219 self .model_token = models_tokens ["hugging_face" ][llm_params ["model" ]]
208220 except KeyError :
@@ -277,12 +289,13 @@ def _create_default_embedder(self, llm_config=None) -> object:
277289 if isinstance (self .llm_model , OpenAI ):
278290 return OpenAIEmbeddings (api_key = self .llm_model .openai_api_key , base_url = self .llm_model .openai_api_base )
279291 elif isinstance (self .llm_model , DeepSeek ):
280- return OpenAIEmbeddings (api_key = self .llm_model .openai_api_key )
281-
292+ return OpenAIEmbeddings (api_key = self .llm_model .openai_api_key )
282293 elif isinstance (self .llm_model , AzureOpenAIEmbeddings ):
283294 return self .llm_model
284295 elif isinstance (self .llm_model , AzureOpenAI ):
285296 return AzureOpenAIEmbeddings ()
297+ elif isinstance (self .llm_model , Fireworks ):
298+ return FireworksEmbeddings (model = self .llm_model .model_name )
286299 elif isinstance (self .llm_model , Ollama ):
287300 # unwrap the kwargs from the model which is a dict
288301 params = self .llm_model ._lc_kwargs
@@ -333,6 +346,13 @@ def _create_embedder(self, embedder_config: dict) -> object:
333346 except KeyError as exc :
334347 raise KeyError ("Model not supported" ) from exc
335348 return HuggingFaceHubEmbeddings (model = embedder_params ["model" ])
349+ elif "fireworks" in embedder_params ["model" ]:
350+ embedder_params ["model" ] = "/" .join (embedder_params ["model" ].split ("/" )[1 :])
351+ try :
352+ models_tokens ["fireworks" ][embedder_params ["model" ]]
353+ except KeyError as exc :
354+ raise KeyError ("Model not supported" ) from exc
355+ return FireworksEmbeddings (model = embedder_params ["model" ])
336356 elif "gemini" in embedder_params ["model" ]:
337357 try :
338358 models_tokens ["gemini" ][embedder_params ["model" ]]
0 commit comments