|
22 | 22 | RERANK_BACKEND, |
23 | 23 | RERANK_SERVICE_MODEL_UID, |
24 | 24 | RERANK_SERVICE_URL, |
25 | | - VECTOR_SIZE, |
| 25 | + EMBEDDING_DIMENSIONS, |
26 | 26 | ) |
27 | 27 | from aperag.query.query import DocumentWithScore |
28 | 28 | from aperag.vectorstore.connector import VectorStoreConnectorAdaptor |
29 | 29 |
|
30 | 30 |
|
31 | 31 | class EmbeddingService(Embeddings): |
32 | | - def __init__(self, model_type): |
33 | | - if EMBEDDING_BACKEND == "local": |
| 32 | + def __init__(self, embedding_backend, model_type): |
| 33 | + if embedding_backend == "local": |
34 | 34 | self.model = EMBEDDING_MODEL_CLS.get(model_type)() |
35 | | - elif EMBEDDING_BACKEND == "xinference": |
| 35 | + elif embedding_backend == "xinference": |
36 | 36 | self.model = XinferenceEmbedding() |
37 | | - elif EMBEDDING_BACKEND == "openai": |
| 37 | + elif embedding_backend == "openai": |
38 | 38 | self.model = OpenAIEmbedding() |
39 | 39 | else: |
40 | 40 | raise Exception("Unsupported embedding backend") |
@@ -336,15 +336,24 @@ def get_embedding_model(model_type: str = "bge", load=True, **kwargs) -> {Embedd |
336 | 336 | return embedding_model_cache[model_type] |
337 | 337 |
|
338 | 338 | embedding_model = None |
339 | | - vector_size = VECTOR_SIZE.get(model_type, 1024) |
| 339 | + vector_size = get_embedding_dimension(EMBEDDING_BACKEND, model_type, EMBEDDING_SERVICE_MODEL) |
340 | 340 |
|
341 | 341 | if load: |
342 | | - embedding_model = EmbeddingService(model_type) |
| 342 | + embedding_model = EmbeddingService(EMBEDDING_BACKEND, model_type) |
343 | 343 | embedding_model_cache[model_type] = (embedding_model, vector_size) |
344 | 344 |
|
345 | 345 | return embedding_model, vector_size |
346 | 346 |
|
347 | 347 |
|
def get_embedding_dimension(embedding_backend: str, model_type: str, service_model: Optional[str] = None) -> int:
    """Look up the output vector dimension for an embedding configuration.

    Args:
        embedding_backend: backend name, e.g. "local", "xinference", "openai";
            must be a key of EMBEDDING_DIMENSIONS.
        model_type: local model-type key (used for non-openai backends).
        service_model: remote service model name (used for the "openai"
            backend); may be None, in which case the backend's "__default__"
            dimension is returned.

    Returns:
        The embedding dimension as an int.

    Raises:
        Exception: if embedding_backend has no entry in EMBEDDING_DIMENSIONS
            (previously this surfaced as an opaque KeyError: '__default__').
    """
    rules = EMBEDDING_DIMENSIONS.get(embedding_backend)
    if not rules:
        # Fail loudly with a clear message instead of KeyError('__default__')
        # when an unknown backend slips through configuration.
        raise Exception(f"Unsupported embedding backend: {embedding_backend}")

    # OpenAI dimensions are keyed by the remote service model name; every
    # other backend is keyed by the local model type.
    key = service_model if embedding_backend == "openai" else model_type
    return rules.get(key, rules["__default__"])
| 356 | + |
348 | 357 | async def rerank(message, results): |
349 | 358 | model = get_rerank_model() |
350 | 359 | results = await model.rank(message, results) |
|
0 commit comments