|
5 | 5 | from abc import ABC, abstractmethod |
6 | 6 | from typing import Optional |
7 | 7 |
|
8 | | -from ..models import OpenAI, Gemini, Ollama, AzureOpenAI, HuggingFace, Groq, Bedrock |
| 8 | +from langchain_aws.embeddings.bedrock import BedrockEmbeddings |
| 9 | +from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings |
| 10 | +from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings |
| 11 | + |
9 | 12 | from ..helpers import models_tokens |
| 13 | +from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI |
10 | 14 |
|
11 | 15 |
|
12 | 16 | class AbstractGraph(ABC): |
@@ -43,7 +47,8 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None): |
43 | 47 | self.source = source |
44 | 48 | self.config = config |
45 | 49 | self.llm_model = self._create_llm(config["llm"], chat=True) |
46 | | - self.embedder_model = self.llm_model if "embeddings" not in config else self._create_llm( |
| 50 | + self.embedder_model = self._create_default_embedder( |
| 51 | + ) if "embeddings" not in config else self._create_embedder( |
47 | 52 | config["embeddings"]) |
48 | 53 |
|
49 | 54 | # Set common configuration parameters |
@@ -172,6 +177,85 @@ def _create_llm(self, llm_config: dict, chat=False) -> object: |
172 | 177 | else: |
173 | 178 | raise ValueError( |
174 | 179 | "Model provided by the configuration not supported") |
| 180 | + |
| 181 | + def _create_default_embedder(self) -> object: |
| 182 | + """ |
| 183 | + Create an embedding model instance based on the chosen llm model. |
| 184 | +
|
| 185 | + Returns: |
| 186 | + object: An instance of the embedding model client. |
| 187 | +
|
| 188 | + Raises: |
| 189 | + ValueError: If the model is not supported. |
| 190 | + """ |
| 191 | + |
| 192 | + if isinstance(self.llm_model, OpenAI): |
| 193 | + return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key) |
| 194 | + elif isinstance(self.llm_model, AzureOpenAIEmbeddings): |
| 195 | + return self.llm_model |
| 196 | + elif isinstance(self.llm_model, AzureOpenAI): |
| 197 | + return AzureOpenAIEmbeddings() |
| 198 | + elif isinstance(self.llm_model, Ollama): |
| 199 | + # unwrap the kwargs from the model whihc is a dict |
| 200 | + params = self.llm_model._lc_kwargs |
| 201 | + # remove streaming and temperature |
| 202 | + params.pop("streaming", None) |
| 203 | + params.pop("temperature", None) |
| 204 | + |
| 205 | + return OllamaEmbeddings(**params) |
| 206 | + elif isinstance(self.llm_model, HuggingFace): |
| 207 | + return HuggingFaceHubEmbeddings(model=self.llm_model.model) |
| 208 | + elif isinstance(self.llm_model, Bedrock): |
| 209 | + return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id) |
| 210 | + else: |
| 211 | + raise ValueError("Embedding Model missing or not supported") |
| 212 | + |
| 213 | + def _create_embedder(self, embedder_config: dict) -> object: |
| 214 | + """ |
| 215 | + Create an embedding model instance based on the configuration provided. |
| 216 | +
|
| 217 | + Args: |
| 218 | + embedder_config (dict): Configuration parameters for the embedding model. |
| 219 | +
|
| 220 | + Returns: |
| 221 | + object: An instance of the embedding model client. |
| 222 | +
|
| 223 | + Raises: |
| 224 | + KeyError: If the model is not supported. |
| 225 | + """ |
| 226 | + |
| 227 | + # Instantiate the embedding model based on the model name |
| 228 | + if "openai" in embedder_config["model"]: |
| 229 | + return OpenAIEmbeddings(api_key=embedder_config["api_key"]) |
| 230 | + |
| 231 | + elif "azure" in embedder_config["model"]: |
| 232 | + return AzureOpenAIEmbeddings() |
| 233 | + |
| 234 | + elif "ollama" in embedder_config["model"]: |
| 235 | + embedder_config["model"] = embedder_config["model"].split("/")[-1] |
| 236 | + try: |
| 237 | + models_tokens["ollama"][embedder_config["model"]] |
| 238 | + except KeyError: |
| 239 | + raise KeyError("Model not supported") |
| 240 | + return OllamaEmbeddings(**embedder_config) |
| 241 | + |
| 242 | + elif "hugging_face" in embedder_config["model"]: |
| 243 | + try: |
| 244 | + models_tokens["hugging_face"][embedder_config["model"]] |
| 245 | + except KeyError: |
| 246 | + raise KeyError("Model not supported") |
| 247 | + return HuggingFaceHubEmbeddings(model=embedder_config["model"]) |
| 248 | + |
| 249 | + elif "bedrock" in embedder_config["model"]: |
| 250 | + embedder_config["model"] = embedder_config["model"].split("/")[-1] |
| 251 | + try: |
| 252 | + models_tokens["bedrock"][embedder_config["model"]] |
| 253 | + except KeyError: |
| 254 | + raise KeyError("Model not supported") |
| 255 | + return BedrockEmbeddings(client=None, model_id=embedder_config["model"]) |
| 256 | + else: |
| 257 | + raise ValueError( |
| 258 | + "Model provided by the configuration not supported") |
175 | 259 |
|
176 | 260 | def get_state(self, key=None) -> dict: |
177 | 261 | """"" |
|
0 commit comments