Skip to content

Commit 85686a8

Browse files
authored
Merge pull request #677 from apecloud/feature/support_dynamic_embedding_dimension
feat: support dynamic embedding size
2 parents 99f9d2a + fbe87cd commit 85686a8

File tree

2 files changed

+32
-11
lines changed

2 files changed

+32
-11
lines changed

aperag/readers/base_embedding.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,19 @@
2222
RERANK_BACKEND,
2323
RERANK_SERVICE_MODEL_UID,
2424
RERANK_SERVICE_URL,
25-
VECTOR_SIZE,
25+
EMBEDDING_DIMENSIONS,
2626
)
2727
from aperag.query.query import DocumentWithScore
2828
from aperag.vectorstore.connector import VectorStoreConnectorAdaptor
2929

3030

3131
class EmbeddingService(Embeddings):
    """Embedding facade that instantiates a backend-specific embedding model.

    Args:
        embedding_backend: which embedding provider to use; one of
            "local", "xinference" or "openai".
        model_type: key into EMBEDDING_MODEL_CLS selecting the concrete
            local model class; only consulted when the backend is "local".

    Raises:
        ValueError: if ``embedding_backend`` is not one of the supported
            backends, or a "local" ``model_type`` has no registered class.
    """

    def __init__(self, embedding_backend, model_type):
        if embedding_backend == "local":
            model_cls = EMBEDDING_MODEL_CLS.get(model_type)
            if model_cls is None:
                # Previously this fell through to calling None(), which
                # surfaced as an opaque TypeError; name the bad value instead.
                raise ValueError(f"Unsupported local embedding model type: {model_type}")
            self.model = model_cls()
        elif embedding_backend == "xinference":
            self.model = XinferenceEmbedding()
        elif embedding_backend == "openai":
            self.model = OpenAIEmbedding()
        else:
            # ValueError (not a bare Exception) so callers can handle the
            # configuration error specifically; include the offending value.
            raise ValueError(f"Unsupported embedding backend: {embedding_backend}")
@@ -336,15 +336,24 @@ def get_embedding_model(model_type: str = "bge", load=True, **kwargs) -> {Embedd
336336
return embedding_model_cache[model_type]
337337

338338
embedding_model = None
339-
vector_size = VECTOR_SIZE.get(model_type, 1024)
339+
vector_size = get_embedding_dimension(EMBEDDING_BACKEND, model_type, EMBEDDING_SERVICE_MODEL)
340340

341341
if load:
342-
embedding_model = EmbeddingService(model_type)
342+
embedding_model = EmbeddingService(EMBEDDING_BACKEND, model_type)
343343
embedding_model_cache[model_type] = (embedding_model, vector_size)
344344

345345
return embedding_model, vector_size
346346

347347

def get_embedding_dimension(embedding_backend: str, model_type: str, service_model: str = None,
                            dimensions: dict = None) -> int:
    """Resolve the embedding vector dimension for a backend/model pair.

    Args:
        embedding_backend: "local", "xinference" or "openai".
        model_type: local model key (e.g. "bge"); used for non-openai backends.
        service_model: remote model name (e.g. "text-embedding-3-small");
            used only when the backend is "openai".
        dimensions: optional override of the dimension table; defaults to the
            EMBEDDING_DIMENSIONS mapping from settings. Existing callers are
            unaffected.

    Returns:
        The vector size for the selected model, falling back to the
        backend's "__default__" entry when the model is not listed.

    Raises:
        ValueError: if the backend has no entry in the dimension table.
            (Previously an unknown backend produced an empty rules dict and
            a cryptic ``KeyError: '__default__'``.)
    """
    table = EMBEDDING_DIMENSIONS if dimensions is None else dimensions
    rules = table.get(embedding_backend)
    if rules is None:
        raise ValueError(f"Unsupported embedding backend: {embedding_backend}")
    # openai dimensions are keyed by the remote service model name;
    # the other backends are keyed by the local model type.
    key = service_model if embedding_backend == "openai" else model_type
    return rules.get(key, rules["__default__"])
348357
async def rerank(message, results):
349358
model = get_rerank_model()
350359
results = await model.rank(message, results)

config/settings.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,10 +260,22 @@
260260
# xinference only needs model_uid, doesn't need model name
261261
RERANK_SERVICE_MODEL_UID = env.str("RERANK_SERVICE_MODEL_UID", default="")
262262

# Embedding vector sizes per backend. Each backend maps a model identifier
# to its output dimension; the "__default__" entry is the fallback used
# when a specific model is not listed.
EMBEDDING_DIMENSIONS = {
    "local": {
        "huggingface": 768,
        "text2vec": 768,
        "bge": 1024,
        "__default__": 1024,
    },
    "xinference": {
        "__default__": 1024,
    },
    "openai": {
        "text-embedding-ada-002": 1536,
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "__default__": 1536,
    },
}
268280

269281
# Memory backend

0 commit comments

Comments
 (0)