|
1 | 1 | """ |
2 | 2 | AbstractGraph Module |
3 | 3 | """ |
4 | | - |
5 | 4 | from abc import ABC, abstractmethod |
6 | 5 | from typing import Optional |
7 | | - |
8 | | -from langchain_aws.embeddings.bedrock import BedrockEmbeddings |
9 | | -from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings |
10 | 6 | from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings |
11 | | - |
| 7 | +from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings |
12 | 8 | from ..helpers import models_tokens |
13 | 9 | from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Claude |
| 10 | +from langchain_aws.embeddings.bedrock import BedrockEmbeddings |
| 11 | +from langchain_google_genai import GoogleGenerativeAIEmbeddings |
14 | 12 |
|
15 | 13 |
|
16 | 14 | class AbstractGraph(ABC): |
@@ -69,7 +67,7 @@ def _set_model_token(self, llm): |
69 | 67 | self.model_token = models_tokens["azure"][llm.model_name] |
70 | 68 | except KeyError: |
71 | 69 | raise KeyError("Model not supported") |
72 | | - |
| 70 | + |
73 | 71 | elif 'HuggingFaceEndpoint' in str(type(llm)): |
74 | 72 | if 'mistral' in llm.repo_id: |
75 | 73 | try: |
@@ -229,29 +227,30 @@ def _create_embedder(self, embedder_config: dict) -> object: |
229 | 227 |
|
230 | 228 | if 'model_instance' in embedder_config: |
231 | 229 | return embedder_config['model_instance'] |
232 | | - |
233 | 230 | # Instantiate the embedding model based on the model name |
234 | 231 | if "openai" in embedder_config["model"]: |
235 | 232 | return OpenAIEmbeddings(api_key=embedder_config["api_key"]) |
236 | | - |
237 | 233 | elif "azure" in embedder_config["model"]: |
238 | 234 | return AzureOpenAIEmbeddings() |
239 | | - |
240 | 235 | elif "ollama" in embedder_config["model"]: |
241 | 236 | embedder_config["model"] = embedder_config["model"].split("/")[-1] |
242 | 237 | try: |
243 | 238 | models_tokens["ollama"][embedder_config["model"]] |
244 | 239 | except KeyError as exc: |
245 | 240 | raise KeyError("Model not supported") from exc |
246 | 241 | return OllamaEmbeddings(**embedder_config) |
247 | | - |
248 | 242 | elif "hugging_face" in embedder_config["model"]: |
249 | 243 | try: |
250 | 244 | models_tokens["hugging_face"][embedder_config["model"]] |
251 | 245 | except KeyError as exc: |
252 | 246 | raise KeyError("Model not supported")from exc |
253 | 247 | return HuggingFaceHubEmbeddings(model=embedder_config["model"]) |
254 | | - |
| 248 | + elif "gemini" in embedder_config["model"]: |
| 249 | + try: |
| 250 | + models_tokens["gemini"][embedder_config["model"]] |
| 251 | + except KeyError as exc: |
| 252 | + raise KeyError("Model not supported") from exc |
| 253 | + return GoogleGenerativeAIEmbeddings(model=embedder_config["model"]) |
255 | 254 | elif "bedrock" in embedder_config["model"]: |
256 | 255 | embedder_config["model"] = embedder_config["model"].split("/")[-1] |
257 | 256 | try: |
|
0 commit comments