Skip to content

Commit cb406f2

Browse files
authored
add llama index's ollama embedding and fix type hint of embedding_model (#1071)
* add llama index's ollama embedding and fix type hint of embedding_model
* unify the way embedding models are loaded
* add ollama to supported model types
* fix test code
* add test_load_embedding_model_from_dict
* remove unused import
1 parent b4fa576 commit cb406f2

File tree

17 files changed

+70
-40
lines changed

17 files changed

+70
-40
lines changed

autorag/data/chunk/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import pandas as pd
66

7-
from autorag.embedding.base import embedding_models
7+
from autorag.embedding.base import EmbeddingModel
88
from autorag.data import chunk_modules, sentence_splitter_modules
99
from autorag.utils import result_to_dataframe
1010

@@ -80,7 +80,7 @@ def get_embedding_model(_embed_model_str: str, _module_type: str):
8080
if _embed_model_str == "openai":
8181
if _module_type == "langchain_chunk":
8282
_embed_model_str = "openai_langchain"
83-
return embedding_models[_embed_model_str]()
83+
return EmbeddingModel.load(_embed_model_str)()
8484

8585
# Add embed_model to kwargs
8686
embedding_available_methods = ["semantic_llama_index", "semantic_langchain"]

autorag/embedding/base.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
88
from llama_index.embeddings.openai import OpenAIEmbedding
99
from llama_index.embeddings.openai import OpenAIEmbeddingModelType
10+
from llama_index.embeddings.ollama import OllamaEmbedding
1011
from langchain_openai.embeddings import OpenAIEmbeddings
1112

1213
from autorag import LazyInit
@@ -35,6 +36,7 @@ def _get_vector(self) -> List[float]:
3536
"mock": LazyInit(MockEmbeddingRandom, embed_dim=768),
3637
# langchain
3738
"openai_langchain": LazyInit(OpenAIEmbeddings),
39+
"ollama": LazyInit(OllamaEmbedding),
3840
}
3941

4042
try:
@@ -67,11 +69,13 @@ def _get_vector(self) -> List[float]:
6769

6870
class EmbeddingModel:
6971
@staticmethod
70-
def load(config: Union[str, List[Dict]]):
72+
def load(config: Union[str, Dict, List[Dict]]):
7173
if isinstance(config, str):
7274
return EmbeddingModel.load_from_str(config)
73-
elif isinstance(config, list):
75+
elif isinstance(config, dict):
7476
return EmbeddingModel.load_from_dict(config)
77+
elif isinstance(config, list):
78+
return EmbeddingModel.load_from_list(config)
7579
else:
7680
raise ValueError("Invalid type of config")
7781

@@ -83,11 +87,17 @@ def load_from_str(name: str):
8387
raise ValueError(f"Embedding model '{name}' is not supported")
8488

8589
@staticmethod
86-
def load_from_dict(option: List[dict]):
90+
def load_from_list(option: List[dict]):
91+
if len(option) != 1:
92+
raise ValueError("Only one embedding model is supported")
93+
return EmbeddingModel.load_from_dict(option[0])
94+
95+
@staticmethod
96+
def load_from_dict(option: dict):
8797
def _check_keys(target: dict):
8898
if "type" not in target or "model_name" not in target:
8999
raise ValueError("Both 'type' and 'model_name' must be provided")
90-
if target["type"] not in ["openai", "huggingface", "mock"]:
100+
if target["type"] not in ["openai", "huggingface", "mock", "ollama"]:
91101
raise ValueError(
92102
f"Embedding model type '{target['type']}' is not supported"
93103
)
@@ -102,17 +112,16 @@ def _get_huggingface_class():
102112
return None
103113
return getattr(module, "HuggingFaceEmbedding", None)
104114

105-
if len(option) != 1:
106-
raise ValueError("Only one embedding model is supported")
107-
_check_keys(option[0])
115+
_check_keys(option)
108116

109-
model_options = option[0]
117+
model_options = option
110118
model_type = model_options.pop("type")
111119

112120
embedding_map = {
113121
"openai": OpenAIEmbedding,
114122
"mock": MockEmbeddingRandom,
115123
"huggingface": _get_huggingface_class(),
124+
"ollama": OllamaEmbedding,
116125
}
117126

118127
embedding_class = embedding_map.get(model_type)

autorag/evaluation/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from copy import deepcopy
22
from typing import Union, List, Dict, Tuple, Any
33

4-
from autorag.embedding.base import embedding_models
4+
from autorag.embedding.base import EmbeddingModel
55

66

77
def cast_metrics(
@@ -38,6 +38,6 @@ def cast_metrics(
3838

3939
def cast_embedding_model(key, value):
4040
if key == "embedding_model":
41-
return key, embedding_models[value]()
41+
return key, EmbeddingModel.load(value)()
4242
else:
4343
return key, value

autorag/nodes/passageaugmenter/prev_next_augmenter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from typing import List
1+
from typing import List, Union
22

33
import numpy as np
44
import pandas as pd
55

6-
from autorag.embedding.base import embedding_models
6+
from autorag.embedding.base import EmbeddingModel
77
from autorag.evaluation.metric.util import calculate_cosine_similarity
88
from autorag.nodes.passageaugmenter.base import BasePassageAugmenter
99
from autorag.utils.util import (
@@ -17,7 +17,7 @@
1717

1818
class PrevNextPassageAugmenter(BasePassageAugmenter):
1919
def __init__(
20-
self, project_dir: str, embedding_model: str = "openai", *args, **kwargs
20+
self, project_dir: str, embedding_model: Union[str, dict] = "openai", *args, **kwargs
2121
):
2222
"""
2323
Initialize the PrevNextPassageAugmenter module.
@@ -35,7 +35,7 @@ def __init__(
3535
self.slim_corpus_df = slim_corpus_df
3636

3737
# init embedding model
38-
self.embedding_model = embedding_models[embedding_model]()
38+
self.embedding_model = EmbeddingModel.load(embedding_model)()
3939

4040
def __del__(self):
4141
del self.embedding_model

autorag/nodes/passagefilter/similarity_percentile_cutoff.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
import pandas as pd
66

7-
from autorag.embedding.base import embedding_models
7+
from autorag.embedding.base import EmbeddingModel
88
from autorag.evaluation.metric.util import calculate_cosine_similarity
99
from autorag.nodes.passagefilter.base import BasePassageFilter
1010
from autorag.nodes.passagefilter.similarity_threshold_cutoff import (
@@ -24,8 +24,8 @@ def __init__(self, project_dir: Union[str, Path], *args, **kwargs):
2424
Default is "openai" which is OpenAI text-embedding-ada-002 embedding model.
2525
"""
2626
super().__init__(project_dir, *args, **kwargs)
27-
embedding_model_str = kwargs.pop("embedding_model", "openai")
28-
self.embedding_model = embedding_models[embedding_model_str]()
27+
embedding_model = kwargs.pop("embedding_model", "openai")
28+
self.embedding_model = EmbeddingModel.load(embedding_model)()
2929

3030
def __del__(self):
3131
super().__del__()

autorag/nodes/passagefilter/similarity_threshold_cutoff.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import pandas as pd
55

6-
from autorag.embedding.base import embedding_models
6+
from autorag.embedding.base import EmbeddingModel
77
from autorag.evaluation.metric.util import calculate_cosine_similarity
88
from autorag.nodes.passagefilter.base import BasePassageFilter
99
from autorag.utils.util import (
@@ -24,8 +24,8 @@ def __init__(self, project_dir: str, *args, **kwargs):
2424
Default is "openai" which is OpenAI text-embedding-ada-002 embedding model.
2525
"""
2626
super().__init__(project_dir, *args, **kwargs)
27-
embedding_model_str = kwargs.get("embedding_model", "openai")
28-
self.embedding_model = embedding_models[embedding_model_str]()
27+
embedding_model= kwargs.get("embedding_model", "openai")
28+
self.embedding_model = EmbeddingModel.load(embedding_model)()
2929

3030
def __del__(self):
3131
del self.embedding_model

autorag/schema/module.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@ def from_dict(cls, module_dict: Dict) -> "Module":
2222
module_type = _module_dict.pop("module_type")
2323
module_params = _module_dict
2424
return cls(module_type, module_params)
25+

autorag/utils/util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def make_combinations(target_dict: Dict[str, Any]) -> List[Dict[str, Any]]:
153153
)
154154
)
155155

156+
156157
def delete_duplicate(x):
157158
def is_hashable(obj):
158159
try:

autorag/vectordb/chroma.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional, Dict, Tuple
1+
from typing import List, Optional, Dict, Tuple, Union
22

33
from chromadb import (
44
EphemeralClient,
@@ -18,7 +18,7 @@
1818
class Chroma(BaseVectorStore):
1919
def __init__(
2020
self,
21-
embedding_model: str,
21+
embedding_model: Union[str, List[dict]],
2222
collection_name: str,
2323
embedding_batch: int = 100,
2424
client_type: str = "persistent",

autorag/vectordb/couchbase.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from couchbase.cluster import Cluster
77
from couchbase.options import ClusterOptions
88

9-
from typing import List, Tuple, Optional
9+
from typing import List, Tuple, Optional, Union
1010

1111
from autorag.utils.util import make_batch
1212
from autorag.vectordb import BaseVectorStore
@@ -17,7 +17,7 @@
1717
class Couchbase(BaseVectorStore):
1818
def __init__(
1919
self,
20-
embedding_model: str,
20+
embedding_model: Union[str, List[dict]],
2121
bucket_name: str,
2222
scope_name: str,
2323
collection_name: str,

0 commit comments

Comments
 (0)