Skip to content

Commit fd59f28

Browse files
committed
Bug fixes
1 parent d830d13 commit fd59f28

File tree

2 files changed

+4
-31
lines changed

2 files changed

+4
-31
lines changed

scrapegraphai/graphs/abstract_graph.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,6 @@ def _create_default_embedder(self) -> object:
188188
Raises:
189189
ValueError: If the model is not supported.
190190
"""
191-
192191
if isinstance(self.llm_model, OpenAI):
193192
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
194193
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
@@ -223,6 +222,9 @@ def _create_embedder(self, embedder_config: dict) -> object:
223222
Raises:
224223
KeyError: If the model is not supported.
225224
"""
225+
226+
if 'model_instance' in embedder_config:
227+
return embedder_config['model_instance']
226228

227229
# Instantiate the embedding model based on the model name
228230
if "openai" in embedder_config["model"]:

scrapegraphai/nodes/rag_node.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
99
from langchain_community.document_transformers import EmbeddingsRedundantFilter
1010
from langchain_community.vectorstores import FAISS
11-
from langchain_community.embeddings import OllamaEmbeddings
12-
from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
13-
from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings
1411

1512
from .base_node import BaseNode
1613

@@ -86,33 +83,7 @@ def execute(self, state: dict) -> dict:
8683
print("--- (updated chunks metadata) ---")
8784

8885
# check if embedder_model is provided, if not use llm_model
89-
embedding_model = self.embedder_model if self.embedder_model else self.llm_model
90-
91-
if isinstance(embedding_model, OpenAI):
92-
embeddings = OpenAIEmbeddings(
93-
api_key=embedding_model.openai_api_key)
94-
elif isinstance(embedding_model, AzureOpenAIEmbeddings):
95-
embeddings = embedding_model
96-
elif isinstance(embedding_model, HuggingFaceInferenceAPIEmbeddings):
97-
embeddings = embedding_model
98-
99-
elif isinstance(embedding_model, AzureOpenAI):
100-
embeddings = AzureOpenAIEmbeddings()
101-
elif isinstance(embedding_model, Ollama):
102-
# unwrap the kwargs from the model whihc is a dict
103-
params = embedding_model._lc_kwargs
104-
# remove streaming and temperature
105-
params.pop("streaming", None)
106-
params.pop("temperature", None)
107-
108-
embeddings = OllamaEmbeddings(**params)
109-
elif isinstance(embedding_model, HuggingFace):
110-
embeddings = HuggingFaceHubEmbeddings(model=embedding_model.model)
111-
elif isinstance(embedding_model, Bedrock):
112-
embeddings = BedrockEmbeddings(
113-
client=None, model_id=embedding_model.model_id)
114-
else:
115-
raise ValueError("Embedding Model missing or not supported")
86+
self.embedder_model = self.embedder_model if self.embedder_model else self.llm_model
11687
embeddings = self.embedder_model
11788

11889
retriever = FAISS.from_documents(

0 commit comments

Comments
 (0)