Skip to content

Commit c105c26

Browse files
committed
Update abstract_graph.py
1 parent 8cece1d commit c105c26

File tree

1 file changed

+79
-73
lines changed

1 file changed

+79
-73
lines changed

scrapegraphai/graphs/abstract_graph.py

Lines changed: 79 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -146,78 +146,84 @@ def handle_model(model_name, provider, token_key, default_token=8192):
146146
with warnings.catch_warnings():
147147
warnings.simplefilter("ignore")
148148
return init_chat_model(**llm_params)
149-
150-
if "fireworks" in llm_params["model"]:
151-
model_name = "/".join(llm_params["model"].split("/")[1:])
152-
token_key = llm_params["model"].split("/")[-1]
153-
return handle_model(model_name, "fireworks", token_key)
154-
155-
elif "gemini" in llm_params["model"]:
156-
model_name = llm_params["model"].split("/")[-1]
157-
return handle_model(model_name, "google_genai", model_name)
158-
159-
elif llm_params["model"].startswith("claude"):
160-
model_name = llm_params["model"].split("/")[-1]
161-
return handle_model(model_name, "anthropic", model_name)
162-
163-
elif llm_params["model"].startswith("vertexai"):
164-
return handle_model(llm_params["model"], "google_vertexai", llm_params["model"])
165-
elif "gpt-" in llm_params["model"]:
166-
return handle_model(llm_params["model"], "openai", llm_params["model"])
167-
168-
elif "ollama" in llm_params["model"]:
169-
model_name = llm_params["model"].split("ollama/")[-1]
170-
token_key = model_name if "model_tokens" not in llm_params else llm_params["model_tokens"]
171-
return handle_model(model_name, "ollama", token_key)
172-
173-
elif "claude-3-" in llm_params["model"]:
174-
return handle_model(llm_params["model"], "anthropic", "claude3")
175-
176-
elif llm_params["model"].startswith("mistral"):
177-
model_name = llm_params["model"].split("/")[-1]
178-
return handle_model(model_name, "mistralai", model_name)
179-
180-
# Instantiate the language model based on the model name (models that do not use the common interface)
181-
elif "deepseek" in llm_params["model"]:
182-
try:
183-
self.model_token = models_tokens["deepseek"][llm_params["model"]]
184-
except KeyError:
185-
print("model not found, using default token size (8192)")
186-
self.model_token = 8192
187-
return DeepSeek(llm_params)
188-
189-
elif "ernie" in llm_params["model"]:
190-
try:
191-
self.model_token = models_tokens["ernie"][llm_params["model"]]
192-
except KeyError:
193-
print("model not found, using default token size (8192)")
194-
self.model_token = 8192
195-
return ErnieBotChat(llm_params)
196-
197-
elif "oneapi" in llm_params["model"]:
198-
199-
# take the model after the last dash
200-
llm_params["model"] = llm_params["model"].split("/")[-1]
201-
try:
202-
self.model_token = models_tokens["oneapi"][llm_params["model"]]
203-
except KeyError as exc:
204-
raise KeyError("Model not supported") from exc
205-
return OneApi(llm_params)
206-
207-
elif "nvidia" in llm_params["model"]:
208-
209-
try:
210-
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
211-
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
212-
except KeyError as exc:
213-
raise KeyError("Model not supported") from exc
214-
return ChatNVIDIA(llm_params)
215-
else:
216-
model_name = llm_params["model"].split("/")[-1]
217-
return handle_model(model_name, llm_params["model"], model_name)
218-
219-
raise ValueError("Model provided by the configuration not supported")
220-
149+
150+
known_models = ["azure", "fireworks", "gemini", "claude", "vertexai", "hugging_face", "groq", "gpt-", "ollama", "claude-3-", "bedrock", "mistral", "ernie", "oneapi", "nvidia"]
151+
152+
if llm_params["model"] not in known_models:
153+
raise ValueError(f"Model '{llm_params['model']}' is not supported")
154+
155+
try:
156+
if "fireworks" in llm_params["model"]:
157+
model_name = "/".join(llm_params["model"].split("/")[1:])
158+
token_key = llm_params["model"].split("/")[-1]
159+
return handle_model(model_name, "fireworks", token_key)
160+
161+
elif "gemini" in llm_params["model"]:
162+
model_name = llm_params["model"].split("/")[-1]
163+
return handle_model(model_name, "google_genai", model_name)
164+
165+
elif llm_params["model"].startswith("claude"):
166+
model_name = llm_params["model"].split("/")[-1]
167+
return handle_model(model_name, "anthropic", model_name)
168+
169+
elif llm_params["model"].startswith("vertexai"):
170+
return handle_model(llm_params["model"], "google_vertexai", llm_params["model"])
171+
172+
elif "gpt-" in llm_params["model"]:
173+
return handle_model(llm_params["model"], "openai", llm_params["model"])
174+
175+
elif "ollama" in llm_params["model"]:
176+
model_name = llm_params["model"].split("ollama/")[-1]
177+
token_key = model_name if "model_tokens" not in llm_params else llm_params["model_tokens"]
178+
return handle_model(model_name, "ollama", token_key)
179+
180+
elif "claude-3-" in llm_params["model"]:
181+
return handle_model(llm_params["model"], "anthropic", "claude3")
182+
183+
elif llm_params["model"].startswith("mistral"):
184+
model_name = llm_params["model"].split("/")[-1]
185+
return handle_model(model_name, "mistralai", model_name)
186+
187+
# Instantiate the language model based on the model name (models that do not use the common interface)
188+
elif "deepseek" in llm_params["model"]:
189+
try:
190+
self.model_token = models_tokens["deepseek"][llm_params["model"]]
191+
except KeyError:
192+
print("model not found, using default token size (8192)")
193+
self.model_token = 8192
194+
return DeepSeek(llm_params)
195+
196+
elif "ernie" in llm_params["model"]:
197+
try:
198+
self.model_token = models_tokens["ernie"][llm_params["model"]]
199+
except KeyError:
200+
print("model not found, using default token size (8192)")
201+
self.model_token = 8192
202+
return ErnieBotChat(llm_params)
203+
204+
elif "oneapi" in llm_params["model"]:
205+
# take the model after the last dash
206+
llm_params["model"] = llm_params["model"].split("/")[-1]
207+
try:
208+
self.model_token = models_tokens["oneapi"][llm_params["model"]]
209+
except KeyError:
210+
raise KeyError("Model not supported")
211+
return OneApi(llm_params)
212+
213+
elif "nvidia" in llm_params["model"]:
214+
try:
215+
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
216+
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
217+
except KeyError:
218+
raise KeyError("Model not supported")
219+
return ChatNVIDIA(llm_params)
220+
221+
else:
222+
model_name = llm_params["model"].split("/")[-1]
223+
return handle_model(model_name, llm_params["model"], model_name)
224+
225+
except KeyError as e:
226+
print(f"Model not supported: {e}")
221227

222228
def get_state(self, key=None) -> dict:
223229
""" ""
@@ -264,4 +270,4 @@ def _create_graph(self):
264270
def run(self) -> str:
265271
"""
266272
Abstract method to execute the graph and return the result.
267-
"""
273+
"""

0 commit comments

Comments
 (0)