Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/docs/cli_v2.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ $ infinity_emb v2 --help
│ [env var: │
│ `INFINITY_BATCH_SIZE`] │
│ [default: 32] │
│ --dimensions INTEGER default dimensions for │
│ inference │
│ [env var: │
│ `INFINITY_DIMENSIONS`] │
│ [default: 0] │
│ --revision TEXT huggingface model repo │
│ revision. │
│ [env var: `INFINITY_REVISION`] │
Expand Down
6 changes: 5 additions & 1 deletion libs/infinity_emb/infinity_emb/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class EngineArgs:
Args:
model_name_or_path, str: Defaults to "michaelfeil/bge-small-en-v1.5".
batch_size, int: Defaults to 32.
dimensions, int: Defaults to 0 (no matryoshka slicing).
revision, str: Defaults to None.
trust_remote_code, bool: Defaults to True.
engine, InferenceEngine or str: backend for inference.
Expand All @@ -54,6 +55,7 @@ class EngineArgs:

model_name_or_path: str = MANAGER.model_id[0]
batch_size: int = MANAGER.batch_size[0]
dimensions: int = MANAGER.dimensions[0]
revision: Optional[str] = MANAGER.revision[0]
trust_remote_code: bool = MANAGER.trust_remote_code[0]
engine: InferenceEngine = InferenceEngine[MANAGER.engine[0]]
Expand Down Expand Up @@ -150,6 +152,7 @@ def from_env(cls) -> list["EngineArgs"]:
EngineArgs(
model_name_or_path=model_name_or_path,
batch_size=batch_size,
dimensions=dimensions,
revision=revision,
trust_remote_code=trust_remote_code,
engine=engine,
Expand All @@ -165,9 +168,10 @@ def from_env(cls) -> list["EngineArgs"]:
onnx_disable_optimize=onnx_disable_optimize,
onnx_do_not_prefer_quantized=onnx_do_not_prefer_quantized
)
for model_name_or_path, batch_size, revision, trust_remote_code, engine, model_warmup, device, compile, bettertransformer, dtype, pooling_method, lengths_via_tokenize, embedding_dtype, served_model_name,onnx_disable_optimize,onnx_do_not_prefer_quantized in zip_longest(
for model_name_or_path, batch_size, dimensions, revision, trust_remote_code, engine, model_warmup, device, compile, bettertransformer, dtype, pooling_method, lengths_via_tokenize, embedding_dtype, served_model_name,onnx_disable_optimize,onnx_do_not_prefer_quantized in zip_longest(
MANAGER.model_id,
MANAGER.batch_size,
MANAGER.dimensions,
MANAGER.revision,
MANAGER.trust_remote_code,
MANAGER.engine,
Expand Down
7 changes: 7 additions & 0 deletions libs/infinity_emb/infinity_emb/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def v1(
model_name_or_path: str = MANAGER.model_id[0],
served_model_name: str = MANAGER.served_model_name[0],
batch_size: int = MANAGER.batch_size[0],
dimensions: int = MANAGER.dimensions[0],
revision: str = MANAGER.revision[0],
trust_remote_code: bool = MANAGER.trust_remote_code[0],
redirect_slash: str = MANAGER.redirect_slash,
Expand Down Expand Up @@ -153,6 +154,7 @@ def v1(
model_id=[model_name_or_path],
served_model_name=[served_model_name], # type: ignore
batch_size=[batch_size],
dimensions=[dimensions],
revision=[revision], # type: ignore
trust_remote_code=[trust_remote_code],
engine=[engine],
Expand Down Expand Up @@ -192,6 +194,9 @@ def v2(
batch_size: list[int] = typer.Option(
**_construct("batch_size"), help="maximum batch size for inference"
),
dimensions: list[int] = typer.Option(
**_construct("dimensions"), help="default dimensions for inference"
),
revision: list[str] = typer.Option(
**_construct("revision"), help="huggingface model repo revision."
),
Expand Down Expand Up @@ -293,6 +298,7 @@ def v2(
Defaults to `INFINITY_MODEL_ID`
served_model_name, list[str]: "", e.g. ["bge-small-en-v1.5"]
batch_size, list[int]: batch size for forward pass.
dimensions, list[int]: default dimensions for inference.
revision: list[str]: revision of the model.
trust_remote_code, list[bool]: trust remote code.
url_prefix, str: prefix for api. typically "".
Expand Down Expand Up @@ -326,6 +332,7 @@ def v2(
length=len(model_id),
model_name_or_path=model_id,
batch_size=batch_size,
dimensions=dimensions,
revision=revision,
trust_remote_code=trust_remote_code,
engine=engine,
Expand Down
1 change: 1 addition & 0 deletions libs/infinity_emb/infinity_emb/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ async def astart(self):
self.running = True
self._batch_handler = BatchHandler(
max_batch_size=self._engine_args.batch_size,
matryoshka_dim=self._engine_args.dimensions,
model_replicas=self._model_replicas,
# batch_delay=self._min_inference_t / 2,
vector_disk_cache_path=self._engine_args.vector_disk_cache_path,
Expand Down
6 changes: 6 additions & 0 deletions libs/infinity_emb/infinity_emb/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ def batch_size(self):
self._optional_infinity_var_multiple("batch_size", default=["32"])
)

@cached_property
def dimensions(self):
    """Per-model default matryoshka output dimensions.

    Read once from the `INFINITY_DIMENSIONS` env var (one value per
    model); each entry defaults to "0", i.e. no matryoshka slicing.
    """
    raw_values = self._optional_infinity_var_multiple("dimensions", default=["0"])
    return self._to_int_multiple(raw_values)

@cached_property
def revision(self):
    # Per-model Hugging Face repo revision, read from the INFINITY_REVISION
    # env var (one value per model). An empty string means "use the repo's
    # default branch".
    return self._optional_infinity_var_multiple("revision", default=[""])
Expand Down
6 changes: 6 additions & 0 deletions libs/infinity_emb/infinity_emb/inference/batch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(
self,
model_replicas: list["BaseTypeHint"],
max_batch_size: int,
matryoshka_dim: Optional[int] = None,
max_queue_wait: int = MANAGER.queue_size,
batch_delay: float = 5e-3,
vector_disk_cache_path: str = "",
Expand All @@ -92,6 +93,7 @@ def __init__(
Args:
model (BaseTransformer): the base class of the model to be used
max_batch_size (int): max batch size of dynamic batch size
matryoshka_dim (int, optional): default dimensions for matryoshka slicing.
max_queue_wait (int, optional): max items to queue in the batch, default 32_000
batch_delay (float, optional): sleep in seconds, wait time for pre/post methods.
Best result: setting to 1/2 the minimal expected
Expand All @@ -112,6 +114,7 @@ def __init__(
self._result_queue: Queue = Queue(8)

self.max_batch_size = max_batch_size
self.matryoshka_dim = matryoshka_dim
self._verbose = verbose
self.batch_delay = batch_delay

Expand Down Expand Up @@ -172,6 +175,7 @@ async def embed(
input_sentences = [EmbeddingSingle(sentence=s) for s in sentences]

embeddings, usage = await self._schedule(input_sentences)
matryoshka_dim = matryoshka_dim if matryoshka_dim else self.matryoshka_dim
return matryososka_slice(embeddings, matryoshka_dim), usage

async def rerank(
Expand Down Expand Up @@ -278,6 +282,7 @@ async def image_embed(

items = await resolve_images(images)
embeddings, usage = await self._schedule(items)
matryoshka_dim = matryoshka_dim if matryoshka_dim else self.matryoshka_dim
return matryososka_slice(embeddings, matryoshka_dim), usage

async def audio_embed(
Expand Down Expand Up @@ -308,6 +313,7 @@ async def audio_embed(
getattr(self.model_worker[0]._model, "sampling_rate", -42),
)
embeddings, usage = await self._schedule(items)
matryoshka_dim = matryoshka_dim if matryoshka_dim else self.matryoshka_dim
return matryososka_slice(embeddings, matryoshka_dim), usage

async def _schedule(self, list_queueitem: Sequence[AbstractSingle]) -> tuple[list[Any], int]:
Expand Down
1 change: 1 addition & 0 deletions libs/infinity_emb/infinity_emb/infinity_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def create_server(
permissive_cors: bool = MANAGER.permissive_cors,
api_key: str = MANAGER.api_key,
proxy_root_path: str = MANAGER.proxy_root_path,
dimensions: int = MANAGER.dimensions,
):
"""
creates the FastAPI server for a set of EngineArgs.
Expand Down
29 changes: 29 additions & 0 deletions libs/infinity_emb/tests/end_to_end/test_api_with_dummymodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
PREFIX = ""
MODEL_NAME = "dummy-number-1"
MODEL_NAME_2 = "dummy-number-2"
MODEL_NAME_3 = "dummy-number-3"
DEFAULT_DIMENSIONS = 5
BATCH_SIZE = 16

PATH_OPENAPI = pathlib.Path(__file__).parent.parent.parent.parent.parent.joinpath(
Expand All @@ -38,6 +40,12 @@
batch_size=BATCH_SIZE,
engine=InferenceEngine.debugengine,
),
EngineArgs(
model_name_or_path=MODEL_NAME_3,
batch_size=BATCH_SIZE,
dimensions=DEFAULT_DIMENSIONS,
engine=InferenceEngine.debugengine,
),
],
)

Expand Down Expand Up @@ -193,3 +201,24 @@ async def test_matryoshka_embedding(client):
for embedding, sentence in zip(rdata["data"], inp):
assert len(sentence) == embedding["embedding"][0]
assert len(embedding["embedding"]) == matryoshka_dim


@pytest.mark.anyio
async def test_matryoshka_embedding_default_dimensions(client):
    """Requests without an explicit dimensions field are sliced to the
    engine-level default (DEFAULT_DIMENSIONS) configured for MODEL_NAME_3."""
    single = ["This is a test sentence."]
    pair = ["This is a test sentence.", "This is another test sentence."]
    for sentences in (single, pair):
        response = await client.post(
            f"{PREFIX}/embeddings",
            json={"input": sentences, "model": MODEL_NAME_3},
        )
        assert response.status_code == 200, f"{response.status_code}, {response.text}"
        payload = response.json()
        assert "data" in payload and isinstance(payload["data"], list)
        assert all("embedding" in entry for entry in payload["data"])
        assert len(payload["data"]) == len(sentences)
        for entry, sentence in zip(payload["data"], sentences):
            # The debug engine encodes the input's length in the first
            # embedding component — presumably its fixture contract; the
            # slicing itself is what this test is really pinning down.
            assert entry["embedding"][0] == len(sentence)
            assert len(entry["embedding"]) == DEFAULT_DIMENSIONS