diff --git a/docs/docs/cli_v2.md b/docs/docs/cli_v2.md index 0bae7fce..cca8d59f 100644 --- a/docs/docs/cli_v2.md +++ b/docs/docs/cli_v2.md @@ -34,6 +34,11 @@ $ infinity_emb v2 --help │ [env var: │ │ `INFINITY_BATCH_SIZE`] │ │ [default: 32] │ +│ --dimensions INTEGER default dimensions for │ +│ inference │ +│ [env var: │ +│ `INFINITY_DIMENSIONS`] │ +│ [default: 0] │ │ --revision TEXT huggingface model repo │ │ revision. │ │ [env var: `INFINITY_REVISION`] │ diff --git a/libs/infinity_emb/infinity_emb/args.py b/libs/infinity_emb/infinity_emb/args.py index fde57081..a931c61f 100644 --- a/libs/infinity_emb/infinity_emb/args.py +++ b/libs/infinity_emb/infinity_emb/args.py @@ -34,6 +34,7 @@ class EngineArgs: Args: model_name_or_path, str: Defaults to "michaelfeil/bge-small-en-v1.5". batch_size, int: Defaults to 32. + dimensions, int: Defaults to 0 (no matryoshka slicing). revision, str: Defaults to None. trust_remote_code, bool: Defaults to True. engine, InferenceEngine or str: backend for inference. 
@@ -54,6 +55,7 @@ class EngineArgs: model_name_or_path: str = MANAGER.model_id[0] batch_size: int = MANAGER.batch_size[0] + dimensions: int = MANAGER.dimensions[0] revision: Optional[str] = MANAGER.revision[0] trust_remote_code: bool = MANAGER.trust_remote_code[0] engine: InferenceEngine = InferenceEngine[MANAGER.engine[0]] @@ -150,6 +152,7 @@ def from_env(cls) -> list["EngineArgs"]: EngineArgs( model_name_or_path=model_name_or_path, batch_size=batch_size, + dimensions=dimensions, revision=revision, trust_remote_code=trust_remote_code, engine=engine, @@ -165,9 +168,10 @@ def from_env(cls) -> list["EngineArgs"]: onnx_disable_optimize=onnx_disable_optimize, onnx_do_not_prefer_quantized=onnx_do_not_prefer_quantized ) - for model_name_or_path, batch_size, revision, trust_remote_code, engine, model_warmup, device, compile, bettertransformer, dtype, pooling_method, lengths_via_tokenize, embedding_dtype, served_model_name,onnx_disable_optimize,onnx_do_not_prefer_quantized in zip_longest( + for model_name_or_path, batch_size, dimensions, revision, trust_remote_code, engine, model_warmup, device, compile, bettertransformer, dtype, pooling_method, lengths_via_tokenize, embedding_dtype, served_model_name,onnx_disable_optimize,onnx_do_not_prefer_quantized in zip_longest( MANAGER.model_id, MANAGER.batch_size, + MANAGER.dimensions, MANAGER.revision, MANAGER.trust_remote_code, MANAGER.engine, diff --git a/libs/infinity_emb/infinity_emb/cli.py b/libs/infinity_emb/infinity_emb/cli.py index 58b41810..deefaffc 100644 --- a/libs/infinity_emb/infinity_emb/cli.py +++ b/libs/infinity_emb/infinity_emb/cli.py @@ -113,6 +113,7 @@ def v1( model_name_or_path: str = MANAGER.model_id[0], served_model_name: str = MANAGER.served_model_name[0], batch_size: int = MANAGER.batch_size[0], + dimensions: int = MANAGER.dimensions[0], revision: str = MANAGER.revision[0], trust_remote_code: bool = MANAGER.trust_remote_code[0], redirect_slash: str = MANAGER.redirect_slash, @@ -153,6 +154,7 @@ def v1( 
model_id=[model_name_or_path], served_model_name=[served_model_name], # type: ignore batch_size=[batch_size], + dimensions=[dimensions], revision=[revision], # type: ignore trust_remote_code=[trust_remote_code], engine=[engine], @@ -192,6 +194,9 @@ def v2( batch_size: list[int] = typer.Option( **_construct("batch_size"), help="maximum batch size for inference" ), + dimensions: list[int] = typer.Option( + **_construct("dimensions"), help="default dimensions for inference" + ), revision: list[str] = typer.Option( **_construct("revision"), help="huggingface model repo revision." ), @@ -293,6 +298,7 @@ def v2( Defaults to `INFINITY_MODEL_ID` served_model_name, list[str]: "", e.g. ["bge-small-en-v1.5"] batch_size, list[int]: batch size for forward pass. + dimensions, list[int]: default dimensions for inference. revision: list[str]: revision of the model. trust_remote_code, list[bool]: trust remote code. url_prefix, str: prefix for api. typically "". @@ -326,6 +332,7 @@ def v2( length=len(model_id), model_name_or_path=model_id, batch_size=batch_size, + dimensions=dimensions, revision=revision, trust_remote_code=trust_remote_code, engine=engine, diff --git a/libs/infinity_emb/infinity_emb/engine.py b/libs/infinity_emb/infinity_emb/engine.py index 153e15ba..bbc9dcd2 100644 --- a/libs/infinity_emb/infinity_emb/engine.py +++ b/libs/infinity_emb/infinity_emb/engine.py @@ -88,6 +88,7 @@ async def astart(self): self.running = True self._batch_handler = BatchHandler( max_batch_size=self._engine_args.batch_size, + matryoshka_dim=self._engine_args.dimensions or None, model_replicas=self._model_replicas, # batch_delay=self._min_inference_t / 2, vector_disk_cache_path=self._engine_args.vector_disk_cache_path, diff --git a/libs/infinity_emb/infinity_emb/env.py b/libs/infinity_emb/infinity_emb/env.py index 48833e47..a3258b8f 100644 --- a/libs/infinity_emb/infinity_emb/env.py +++ b/libs/infinity_emb/infinity_emb/env.py @@ -107,6 +107,12 @@ def batch_size(self):
self._optional_infinity_var_multiple("batch_size", default=["32"]) ) + @cached_property + def dimensions(self): + return self._to_int_multiple( + self._optional_infinity_var_multiple("dimensions", default=["0"]) + ) + @cached_property def revision(self): return self._optional_infinity_var_multiple("revision", default=[""]) diff --git a/libs/infinity_emb/infinity_emb/inference/batch_handler.py b/libs/infinity_emb/infinity_emb/inference/batch_handler.py index 24a7c49b..0584d277 100644 --- a/libs/infinity_emb/infinity_emb/inference/batch_handler.py +++ b/libs/infinity_emb/infinity_emb/inference/batch_handler.py @@ -79,6 +79,7 @@ def __init__( self, model_replicas: list["BaseTypeHint"], max_batch_size: int, + matryoshka_dim: Optional[int] = None, max_queue_wait: int = MANAGER.queue_size, batch_delay: float = 5e-3, vector_disk_cache_path: str = "", @@ -92,6 +93,7 @@ def __init__( Args: model (BaseTransformer): the base class of the model to be used max_batch_size (int): max batch size of dynamic batch size + matryoshka_dim (int, optional): default dimensions for matryoshka slicing. max_queue_wait (int, optional): max items to queue in the batch, default 32_000 batch_delay (float, optional): sleep in seconds, wait time for pre/post methods. 
Best result: setting to 1/2 the minimal expected @@ -112,6 +114,7 @@ def __init__( self._result_queue: Queue = Queue(8) self.max_batch_size = max_batch_size + self.matryoshka_dim = matryoshka_dim self._verbose = verbose self.batch_delay = batch_delay @@ -172,6 +175,7 @@ async def embed( input_sentences = [EmbeddingSingle(sentence=s) for s in sentences] embeddings, usage = await self._schedule(input_sentences) + matryoshka_dim = matryoshka_dim if matryoshka_dim else self.matryoshka_dim return matryososka_slice(embeddings, matryoshka_dim), usage async def rerank( @@ -278,6 +282,7 @@ async def image_embed( items = await resolve_images(images) embeddings, usage = await self._schedule(items) + matryoshka_dim = matryoshka_dim if matryoshka_dim else self.matryoshka_dim return matryososka_slice(embeddings, matryoshka_dim), usage async def audio_embed( @@ -308,6 +313,7 @@ async def audio_embed( getattr(self.model_worker[0]._model, "sampling_rate", -42), ) embeddings, usage = await self._schedule(items) + matryoshka_dim = matryoshka_dim if matryoshka_dim else self.matryoshka_dim return matryososka_slice(embeddings, matryoshka_dim), usage async def _schedule(self, list_queueitem: Sequence[AbstractSingle]) -> tuple[list[Any], int]: diff --git a/libs/infinity_emb/infinity_emb/infinity_server.py b/libs/infinity_emb/infinity_emb/infinity_server.py index 382cc6e8..b73e821b 100644 --- a/libs/infinity_emb/infinity_emb/infinity_server.py +++ b/libs/infinity_emb/infinity_emb/infinity_server.py @@ -57,6 +57,7 @@ def create_server( permissive_cors: bool = MANAGER.permissive_cors, api_key: str = MANAGER.api_key, proxy_root_path: str = MANAGER.proxy_root_path, + dimensions: list[int] = MANAGER.dimensions, ): """ creates the FastAPI server for a set of EngineArgs.
diff --git a/libs/infinity_emb/tests/end_to_end/test_api_with_dummymodel.py b/libs/infinity_emb/tests/end_to_end/test_api_with_dummymodel.py index 1e8d1aa4..16d86ef4 100644 --- a/libs/infinity_emb/tests/end_to_end/test_api_with_dummymodel.py +++ b/libs/infinity_emb/tests/end_to_end/test_api_with_dummymodel.py @@ -19,6 +19,8 @@ PREFIX = "" MODEL_NAME = "dummy-number-1" MODEL_NAME_2 = "dummy-number-2" +MODEL_NAME_3 = "dummy-number-3" +DEFAULT_DIMENSIONS = 5 BATCH_SIZE = 16 PATH_OPENAPI = pathlib.Path(__file__).parent.parent.parent.parent.parent.joinpath( @@ -38,6 +40,12 @@ batch_size=BATCH_SIZE, engine=InferenceEngine.debugengine, ), + EngineArgs( + model_name_or_path=MODEL_NAME_3, + batch_size=BATCH_SIZE, + dimensions=DEFAULT_DIMENSIONS, + engine=InferenceEngine.debugengine, + ), ], ) @@ -193,3 +201,24 @@ async def test_matryoshka_embedding(client): for embedding, sentence in zip(rdata["data"], inp): assert len(sentence) == embedding["embedding"][0] assert len(embedding["embedding"]) == matryoshka_dim + + +@pytest.mark.anyio +async def test_matryoshka_embedding_default_dimensions(client): + possible_inputs = [ + ["This is a test sentence."], + ["This is a test sentence.", "This is another test sentence."], + ] + for inp in possible_inputs: + response = await client.post( + f"{PREFIX}/embeddings", + json=dict(input=inp, model=MODEL_NAME_3), + ) + assert response.status_code == 200, f"{response.status_code}, {response.text}" + rdata = response.json() + assert "data" in rdata and isinstance(rdata["data"], list) + assert all("embedding" in d for d in rdata["data"]) + assert len(rdata["data"]) == len(inp) + for embedding, sentence in zip(rdata["data"], inp): + assert len(sentence) == embedding["embedding"][0] + assert len(embedding["embedding"]) == DEFAULT_DIMENSIONS