|
1 | 1 | """ |
2 | 2 | AbstractGraph Module |
3 | 3 | """ |
4 | | - |
5 | 4 | from abc import ABC, abstractmethod |
6 | 5 | from typing import Optional |
7 | | - |
8 | | -from langchain_aws.embeddings.bedrock import BedrockEmbeddings |
9 | | -from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings |
10 | 6 | from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings |
11 | | - |
| 7 | +from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings |
12 | 8 | from ..helpers import models_tokens |
13 | 9 | from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Claude |
| 10 | +from langchain_aws.embeddings.bedrock import BedrockEmbeddings |
| 11 | +from langchain_google_genai import GoogleGenerativeAIEmbeddings |
14 | 12 |
|
15 | 13 |
|
16 | 14 | class AbstractGraph(ABC): |
@@ -86,7 +84,7 @@ def _set_model_token(self, llm): |
86 | 84 | self.model_token = models_tokens["azure"][llm.model_name] |
87 | 85 | except KeyError: |
88 | 86 | raise KeyError("Model not supported") |
89 | | - |
| 87 | + |
90 | 88 | elif 'HuggingFaceEndpoint' in str(type(llm)): |
91 | 89 | if 'mistral' in llm.repo_id: |
92 | 90 | try: |
@@ -246,29 +244,30 @@ def _create_embedder(self, embedder_config: dict) -> object: |
246 | 244 |
|
247 | 245 | if 'model_instance' in embedder_config: |
248 | 246 | return embedder_config['model_instance'] |
249 | | - |
250 | 247 | # Instantiate the embedding model based on the model name |
251 | 248 | if "openai" in embedder_config["model"]: |
252 | 249 | return OpenAIEmbeddings(api_key=embedder_config["api_key"]) |
253 | | - |
254 | 250 | elif "azure" in embedder_config["model"]: |
255 | 251 | return AzureOpenAIEmbeddings() |
256 | | - |
257 | 252 | elif "ollama" in embedder_config["model"]: |
258 | 253 | embedder_config["model"] = embedder_config["model"].split("/")[-1] |
259 | 254 | try: |
260 | 255 | models_tokens["ollama"][embedder_config["model"]] |
261 | 256 | except KeyError as exc: |
262 | 257 | raise KeyError("Model not supported") from exc |
263 | 258 | return OllamaEmbeddings(**embedder_config) |
264 | | - |
265 | 259 | elif "hugging_face" in embedder_config["model"]: |
266 | 260 | try: |
267 | 261 | models_tokens["hugging_face"][embedder_config["model"]] |
268 | 262 | except KeyError as exc: |
269 | 263 | raise KeyError("Model not supported")from exc |
270 | 264 | return HuggingFaceHubEmbeddings(model=embedder_config["model"]) |
271 | | - |
| 265 | + elif "gemini" in embedder_config["model"]: |
| 266 | + try: |
| 267 | + models_tokens["gemini"][embedder_config["model"]] |
| 268 | + except KeyError as exc: |
| 269 | + raise KeyError("Model not supported")from exc |
| 270 | + return GoogleGenerativeAIEmbeddings(model=embedder_config["model"]) |
272 | 271 | elif "bedrock" in embedder_config["model"]: |
273 | 272 | embedder_config["model"] = embedder_config["model"].split("/")[-1] |
274 | 273 | try: |
|
0 commit comments