5 changes: 5 additions & 0 deletions docs/docs/cli_v2.md
@@ -34,6 +34,11 @@ $ infinity_emb v2 --help
│ [env var: │
│ `INFINITY_BATCH_SIZE`] │
│ [default: 32] │
│ --dimensions INTEGER default dimensions for │
│ inference │
│ [env var: │
│ `INFINITY_DIMENSIONS`] │
│ [default: 0] │
│ --revision TEXT huggingface model repo │
│ revision. │
│ [env var: `INFINITY_REVISION`] │

6 changes: 5 additions & 1 deletion libs/infinity_emb/infinity_emb/args.py
@@ -34,6 +34,7 @@ class EngineArgs:
Args:
model_name_or_path, str: Defaults to "michaelfeil/bge-small-en-v1.5".
batch_size, int: Defaults to 32.
dimensions, int: Defaults to 0 (no matryoshka slicing).
revision, str: Defaults to None.
trust_remote_code, bool: Defaults to True.
engine, InferenceEngine or str: backend for inference.
@@ -54,6 +55,7 @@ class EngineArgs:

model_name_or_path: str = MANAGER.model_id[0]
batch_size: int = MANAGER.batch_size[0]
dimensions: int = MANAGER.dimensions[0]
revision: Optional[str] = MANAGER.revision[0]
trust_remote_code: bool = MANAGER.trust_remote_code[0]
engine: InferenceEngine = InferenceEngine[MANAGER.engine[0]]
@@ -148,6 +150,7 @@ def from_env(cls) -> list["EngineArgs"]:
EngineArgs(
model_name_or_path=model_name_or_path,
batch_size=batch_size,
dimensions=dimensions,
revision=revision,
trust_remote_code=trust_remote_code,
engine=engine,
@@ -161,9 +164,10 @@ def from_env(cls) -> list["EngineArgs"]:
embedding_dtype=embedding_dtype,
served_model_name=served_model_name,
)
for model_name_or_path, batch_size, revision, trust_remote_code, engine, model_warmup, device, compile, bettertransformer, dtype, pooling_method, lengths_via_tokenize, embedding_dtype, served_model_name in zip_longest(
for model_name_or_path, batch_size, dimensions, revision, trust_remote_code, engine, model_warmup, device, compile, bettertransformer, dtype, pooling_method, lengths_via_tokenize, embedding_dtype, served_model_name in zip_longest(
MANAGER.model_id,
MANAGER.batch_size,
MANAGER.dimensions,
MANAGER.revision,
MANAGER.trust_remote_code,
MANAGER.engine,
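For readers skimming the diff, a minimal sketch of what the new field looks like from the Python side. The model id and the value 64 are illustrative; a non-zero `dimensions` only makes sense for checkpoints trained with matryoshka representation learning:

from infinity_emb.args import EngineArgs

# Illustrative only: dimensions=0 (the default) keeps the model's native
# output size; a positive value requests matryoshka-sliced embeddings.
args = EngineArgs(
    model_name_or_path="michaelfeil/bge-small-en-v1.5",
    batch_size=32,
    dimensions=64,
)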

7 changes: 7 additions & 0 deletions libs/infinity_emb/infinity_emb/cli.py
@@ -113,6 +113,7 @@ def v1(
model_name_or_path: str = MANAGER.model_id[0],
served_model_name: str = MANAGER.served_model_name[0],
batch_size: int = MANAGER.batch_size[0],
dimensions: int = MANAGER.dimensions[0],
revision: str = MANAGER.revision[0],
trust_remote_code: bool = MANAGER.trust_remote_code[0],
redirect_slash: str = MANAGER.redirect_slash,
@@ -153,6 +154,7 @@ def v1(
model_id=[model_name_or_path],
served_model_name=[served_model_name], # type: ignore
batch_size=[batch_size],
dimensions=[dimensions],
revision=[revision], # type: ignore
trust_remote_code=[trust_remote_code],
engine=[engine],
@@ -192,6 +194,9 @@ def v2(
batch_size: list[int] = typer.Option(
**_construct("batch_size"), help="maximum batch size for inference"
),
dimensions: list[int] = typer.Option(
**_construct("dimensions"), help="default dimensions for inference"
),
revision: list[str] = typer.Option(
**_construct("revision"), help="huggingface model repo revision."
),
@@ -285,6 +290,7 @@ def v2(
Defaults to `INFINITY_MODEL_ID`
served_model_name, list[str]: "", e.g. ["bge-small-en-v1.5"]
batch_size, list[int]: batch size for forward pass.
dimensions, list[int]: default dimensions for inference.
revision: list[str]: revision of the model.
trust_remote_code, list[bool]: trust remote code.
url_prefix, str: prefix for api. typically "".
@@ -316,6 +322,7 @@ def v2(
length=len(model_id),
model_name_or_path=model_id,
batch_size=batch_size,
dimensions=dimensions,
revision=revision,
trust_remote_code=trust_remote_code,
engine=engine,

1 change: 1 addition & 0 deletions libs/infinity_emb/infinity_emb/engine.py
@@ -88,6 +88,7 @@ async def astart(self):
self.running = True
self._batch_handler = BatchHandler(
max_batch_size=self._engine_args.batch_size,
matryoshka_dim=self._engine_args.dimensions,
model_replicas=self._model_replicas,
# batch_delay=self._min_inference_t / 2,
vector_disk_cache_path=self._engine_args.vector_disk_cache_path,
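End to end, the wiring above lets an engine-wide default be set once and applied to every embed call. A minimal sketch against the public engine API, assuming a matryoshka-trained model (the model id and target size are placeholders):

import asyncio

from infinity_emb import AsyncEmbeddingEngine
from infinity_emb.args import EngineArgs

async def main():
    engine = AsyncEmbeddingEngine.from_args(
        EngineArgs(
            model_name_or_path="mixedbread-ai/mxbai-embed-large-v1",  # placeholder
            dimensions=64,  # forwarded to the BatchHandler as matryoshka_dim
        )
    )
    async with engine:  # astart() builds the BatchHandler shown above
        embeddings, usage = await engine.embed(sentences=["hello world"])
        assert len(embeddings[0]) == 64  # sliced by the handler, not the model

asyncio.run(main())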

6 changes: 6 additions & 0 deletions libs/infinity_emb/infinity_emb/env.py
@@ -107,6 +107,12 @@ def batch_size(self):
self._optional_infinity_var_multiple("batch_size", default=["32"])
)

@cached_property
def dimensions(self):
return self._to_int_multiple(
self._optional_infinity_var_multiple("dimensions", default=["0"])
)

@cached_property
def revision(self):
return self._optional_infinity_var_multiple("revision", default=[""])
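Because the property goes through `_optional_infinity_var_multiple`, the same default can also come from the environment, mirroring the `--dimensions` CLI flag above. A minimal single-model sketch; multi-model deployments would supply one value per model, following the same list convention as the other `INFINITY_*` variables:

import os

# Set before first access: `dimensions` is a cached_property, so the
# environment is read only once per process.
os.environ["INFINITY_DIMENSIONS"] = "64"

from infinity_emb.env import MANAGER

assert MANAGER.dimensions == [64]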

6 changes: 6 additions & 0 deletions libs/infinity_emb/infinity_emb/inference/batch_handler.py
@@ -79,6 +79,7 @@ def __init__(
self,
model_replicas: list["BaseTypeHint"],
max_batch_size: int,
matryoshka_dim: Optional[int] = None,
max_queue_wait: int = MANAGER.queue_size,
batch_delay: float = 5e-3,
vector_disk_cache_path: str = "",
@@ -92,6 +93,7 @@ def __init__(
Args:
model (BaseTransformer): the base class of the model to be used
max_batch_size (int): max batch size of dynamic batch size
matryoshka_dim (int, optional): default dimensions for matryoshka slicing.
max_queue_wait (int, optional): max items to queue in the batch, default 32_000
batch_delay (float, optional): sleep in seconds, wait time for pre/post methods.
Best result: setting to 1/2 the minimal expected
@@ -112,6 +114,7 @@ def __init__(
self._result_queue: Queue = Queue(8)

self.max_batch_size = max_batch_size
self.matryoshka_dim = matryoshka_dim
self._verbose = verbose
self.batch_delay = batch_delay

@@ -172,6 +175,7 @@ async def embed(
input_sentences = [EmbeddingSingle(sentence=s) for s in sentences]

embeddings, usage = await self._schedule(input_sentences)
matryoshka_dim = matryoshka_dim if matryoshka_dim else self.matryoshka_dim
return matryososka_slice(embeddings, matryoshka_dim), usage

async def rerank(
@@ -278,6 +282,7 @@ async def image_embed(

items = await resolve_images(images)
embeddings, usage = await self._schedule(items)
matryoshka_dim = matryoshka_dim if matryoshka_dim else self.matryoshka_dim
return matryososka_slice(embeddings, matryoshka_dim), usage

async def audio_embed(
@@ -308,6 +313,7 @@ async def audio_embed(
getattr(self.model_worker[0]._model, "sampling_rate", -42),
)
embeddings, usage = await self._schedule(items)
matryoshka_dim = matryoshka_dim if matryoshka_dim else self.matryoshka_dim
return matryososka_slice(embeddings, matryoshka_dim), usage

async def _schedule(self, list_queueitem: Sequence[AbstractSingle]) -> tuple[list[Any], int]:
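All three embed paths use the same fallback: a per-request `matryoshka_dim` wins, otherwise the engine-wide default applies, and 0/None falls through to no slicing before the call to the pre-existing `matryososka_slice` helper (spelling as in the codebase; untouched by this PR). Conceptually the helper truncates each vector to its leading components; a rough stand-in, with the re-normalization step an assumption rather than a guarantee:

import numpy as np

def matryoshka_slice_sketch(embeddings, dim):
    # Matryoshka-trained models concentrate information in the leading
    # components, so truncating keeps most of the retrieval quality.
    if not dim:  # None or 0 -> keep the native size (dimensions=0 default)
        return embeddings
    sliced = [e[:dim] for e in embeddings]
    # Re-normalizing keeps cosine similarity well-behaved; whether the real
    # helper does this is up to the existing implementation.
    return [e / np.linalg.norm(e) for e in sliced]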

1 change: 1 addition & 0 deletions libs/infinity_emb/infinity_emb/infinity_server.py
@@ -57,6 +57,7 @@ def create_server(
permissive_cors: bool = MANAGER.permissive_cors,
api_key: str = MANAGER.api_key,
proxy_root_path: str = MANAGER.proxy_root_path,
dimensions: list[int] = MANAGER.dimensions,
):
"""
creates the FastAPI server for a set of EngineArgs.

29 changes: 29 additions & 0 deletions libs/infinity_emb/tests/end_to_end/test_api_with_dummymodel.py
@@ -19,6 +19,8 @@
PREFIX = ""
MODEL_NAME = "dummy-number-1"
MODEL_NAME_2 = "dummy-number-2"
MODEL_NAME_3 = "dummy-number-3"
DEFAULT_DIMENSIONS = 5
BATCH_SIZE = 16

PATH_OPENAPI = pathlib.Path(__file__).parent.parent.parent.parent.parent.joinpath(
@@ -38,6 +40,12 @@
batch_size=BATCH_SIZE,
engine=InferenceEngine.debugengine,
),
EngineArgs(
model_name_or_path=MODEL_NAME_3,
batch_size=BATCH_SIZE,
dimensions=DEFAULT_DIMENSIONS,
engine=InferenceEngine.debugengine,
),
],
)

@@ -193,3 +201,24 @@ async def test_matryoshka_embedding(client):
for embedding, sentence in zip(rdata["data"], inp):
assert len(sentence) == embedding["embedding"][0]
assert len(embedding["embedding"]) == matryoshka_dim


@pytest.mark.anyio
async def test_matryoshka_embedding_default_dimensions(client):
possible_inputs = [
["This is a test sentence."],
["This is a test sentence.", "This is another test sentence."],
]
for inp in possible_inputs:
response = await client.post(
f"{PREFIX}/embeddings",
json=dict(input=inp, model=MODEL_NAME_3),
)
assert response.status_code == 200, f"{response.status_code}, {response.text}"
rdata = response.json()
assert "data" in rdata and isinstance(rdata["data"], list)
assert all("embedding" in d for d in rdata["data"])
assert len(rdata["data"]) == len(inp)
for embedding, sentence in zip(rdata["data"], inp):
assert len(sentence) == embedding["embedding"][0]
assert len(embedding["embedding"]) == DEFAULT_DIMENSIONS