Commit b3764ce

feat: specify text embedding dim
1 parent: 97bfd49

File tree

4 files changed: +42 additions, -17 deletions


daft/ai/lm_studio/protocols/text_embedder.py

Lines changed: 13 additions & 1 deletion
@@ -6,7 +6,7 @@
 from openai import OpenAI
 
 from daft import DataType
-from daft.ai.openai.protocols.text_embedder import OpenAITextEmbedder, get_input_text_token_limit_for_model
+from daft.ai.openai.protocols.text_embedder import OpenAITextEmbedder, _models, get_input_text_token_limit_for_model
 from daft.ai.protocols import TextEmbedder, TextEmbedderDescriptor
 from daft.ai.typing import EmbeddingDimensions, EmbedTextOptions, Options, UDFOptions
 from daft.utils import from_dict
@@ -26,10 +26,19 @@ class LMStudioTextEmbedderDescriptor(TextEmbedderDescriptor):
     provider_name: str
     provider_options: OpenAIProviderOptions
     model_name: str
+    dimensions: int | None = None
     embed_options: EmbedTextOptions = field(
         default_factory=lambda: EmbedTextOptions(batch_size=64, max_retries=3, on_error="raise")
     )
 
+    def __post_init__(self) -> None:
+        if self.dimensions is None:
+            return
+        if self.model_name in _models and not _models[self.model_name].supports_overriding_dimensions:
+            raise ValueError(f"Embedding model '{self.model_name}' does not support specifying dimensions")
+        if "supports_overriding_dimensions" not in self.embed_options:
+            self.embed_options["supports_overriding_dimensions"] = True
+
     def get_provider(self) -> str:
         return "lm_studio"
 
@@ -48,6 +57,8 @@ def is_async(self) -> bool:
         return True
 
     def get_dimensions(self) -> EmbeddingDimensions:
+        if self.dimensions is not None:
+            return EmbeddingDimensions(size=self.dimensions, dtype=DataType.float32())
         try:
             client = OpenAI(**self.provider_options)
             response = client.embeddings.create(
@@ -72,6 +83,7 @@ def instantiate(self) -> TextEmbedder:
             provider_options=self.provider_options,
             model=self.model_name,
             embed_options=self.embed_options,
+            dimensions=self.dimensions if self.embed_options.get("supports_overriding_dimensions", False) else None,
             provider_name=self.get_provider(),
             batch_token_limit=batch_token_limit,
             input_text_token_limit=input_text_token_limit,
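
Note: a minimal usage sketch (not part of the commit) of how the new dimensions field behaves on the LM Studio descriptor. The provider_options keys and model id below are assumptions for a local LM Studio server, not values taken from the diff.

from daft.ai.lm_studio.protocols.text_embedder import LMStudioTextEmbedderDescriptor

# Assumed settings for a local LM Studio server; adjust base_url, api_key, and model to your setup.
descriptor = LMStudioTextEmbedderDescriptor(
    provider_name="lm_studio",
    provider_options={"base_url": "http://localhost:1234/v1", "api_key": "lm-studio"},
    model_name="text-embedding-nomic-embed-text-v1.5",
    dimensions=256,  # validated in __post_init__ against the shared _models registry
)

# With dimensions set, get_dimensions() returns the requested size directly
# instead of issuing a probe embedding request to the server.
print(descriptor.get_dimensions())  # expected: EmbeddingDimensions(size=256, dtype=float32)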

daft/ai/lm_studio/provider.py

Lines changed: 1 addition & 6 deletions
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import sys
-import warnings
 from typing import TYPE_CHECKING
 
 if sys.version_info < (3, 11):
@@ -50,14 +49,10 @@ def get_text_embedder(
             LMStudioTextEmbedderDescriptor,
         )
 
-        if dimensions is not None:
-            warnings.warn(
-                f"embed_text dimensions was specified but provider {self.name} currently ignores this property: see https://github.com/Eventual-Inc/Daft/issues/5555"
-            )
-
         return LMStudioTextEmbedderDescriptor(
             provider_name=self._name,
             provider_options=self._options,
             model_name=(model or self.DEFAULT_TEXT_EMBEDDER),
+            dimensions=dimensions,
             embed_options=options,
         )
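
Note: with the warning removed, a dimensions argument now flows from the provider into the descriptor rather than being ignored. A rough end-to-end sketch follows; the embed_text import path and signature, the column name, and the model id are assumptions, not part of this commit.

import daft
from daft.functions import embed_text  # assumed import path for daft's embed_text expression

df = daft.from_pydict({"text": ["hello world", "daft embeddings"]})
df = df.with_column(
    "embedding",
    embed_text(
        daft.col("text"),
        provider="lm_studio",
        model="text-embedding-nomic-embed-text-v1.5",  # assumed LM Studio model id
        dimensions=256,  # previously ignored with a warning; now forwarded to the descriptor
    ),
)
df.show()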

daft/ai/transformers/protocols/text_embedder.py

Lines changed: 23 additions & 3 deletions
@@ -25,8 +25,20 @@
 @dataclass
 class TransformersTextEmbedderDescriptor(TextEmbedderDescriptor):
     model: str
+    dimensions: int | None = None
     embed_options: EmbedTextOptions = field(default_factory=lambda: EmbedTextOptions(batch_size=64))
 
+    def __post_init__(self) -> None:
+        if self.dimensions is None:
+            return
+        if self.dimensions <= 0:
+            raise ValueError("Embedding dimensions must be a positive integer.")
+        dimensions = AutoConfig.from_pretrained(self.model, trust_remote_code=True).hidden_size
+        if self.dimensions > dimensions:
+            raise ValueError(
+                f"Requested dimensions ({self.dimensions}) exceeds model output size ({dimensions}) for '{self.model}'."
+            )
+
     def get_provider(self) -> str:
         return "transformers"
 
@@ -37,6 +49,8 @@ def get_options(self) -> Options:
         return dict(self.embed_options)
 
     def get_dimensions(self) -> EmbeddingDimensions:
+        if self.dimensions is not None:
+            return EmbeddingDimensions(size=self.dimensions, dtype=DataType.float32())
         dimensions = AutoConfig.from_pretrained(self.model, trust_remote_code=True).hidden_size
         return EmbeddingDimensions(size=dimensions, dtype=DataType.float32())
 
@@ -48,20 +62,26 @@ def get_udf_options(self) -> UDFOptions:
         return udf_options
 
     def instantiate(self) -> TextEmbedder:
-        return TransformersTextEmbedder(self.model, **self.embed_options)
+        return TransformersTextEmbedder(self.model, dimensions=self.dimensions, **self.embed_options)
 
 
 class TransformersTextEmbedder(TextEmbedder):
     model: SentenceTransformer
     embed_options: EmbedTextOptions
 
-    def __init__(self, model_name_or_path: str, **embed_options: Unpack[EmbedTextOptions]):
+    def __init__(
+        self,
+        model_name_or_path: str,
+        dimensions: int | None = None,
+        **embed_options: Unpack[EmbedTextOptions],
+    ):
         # Let SentenceTransformer handle device selection automatically.
         self.model = SentenceTransformer(model_name_or_path, trust_remote_code=True, backend="torch")
         self.model.eval()
         self.embed_options = embed_options
+        self.dimensions = dimensions
 
     def embed_text(self, text: list[str]) -> list[Embedding]:
         with torch.inference_mode():
-            batch = self.model.encode(text, convert_to_numpy=True)
+            batch = self.model.encode(text, convert_to_numpy=True, truncate_dim=self.dimensions)
             return list(batch)
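
Note: a minimal sketch of the validation and short-circuit behavior added above. The model id is illustrative (all-MiniLM-L6-v2 has a hidden size of 384); it is not taken from the commit.

from daft.ai.transformers.protocols.text_embedder import TransformersTextEmbedderDescriptor

# 128 is below the model's 384-dimension output, so construction succeeds.
ok = TransformersTextEmbedderDescriptor(
    model="sentence-transformers/all-MiniLM-L6-v2",
    dimensions=128,
)
print(ok.get_dimensions())  # expected: EmbeddingDimensions(size=128, dtype=float32)

# Requesting more dimensions than the model's hidden_size now fails fast in __post_init__.
try:
    TransformersTextEmbedderDescriptor(
        model="sentence-transformers/all-MiniLM-L6-v2",
        dimensions=512,
    )
except ValueError as err:
    print(err)  # Requested dimensions (512) exceeds model output size (384) for '...'

At embed time the truncation goes through SentenceTransformer's truncate_dim, which keeps the first N components of each embedding; this is most meaningful for Matryoshka-style models trained to support truncation.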

daft/ai/transformers/provider.py

Lines changed: 5 additions & 7 deletions
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import sys
-import warnings
 from typing import TYPE_CHECKING, Any
 
 if sys.version_info < (3, 11):
@@ -91,13 +90,12 @@ def get_text_embedder(
             TransformersTextEmbedderDescriptor,
         )
 
-        if dimensions is not None:
-            warnings.warn(
-                f"embed_text dimensions was specified but provider {self.name} currently ignores this property: see https://github.com/Eventual-Inc/Daft/issues/5555"
-            )
-
         embed_options: EmbedTextOptions = options
-        return TransformersTextEmbedderDescriptor(model or self.DEFAULT_TEXT_EMBEDDER, embed_options=embed_options)
+        return TransformersTextEmbedderDescriptor(
+            model=model or self.DEFAULT_TEXT_EMBEDDER,
+            dimensions=dimensions,
+            embed_options=embed_options,
+        )
 
     def get_image_classifier(
         self,
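
Note: a short sketch of the provider-level call path after this change. The TransformersProvider class name and its no-argument constructor are assumptions based on this module, and the model id is illustrative.

from daft.ai.transformers.provider import TransformersProvider  # assumed class name in this module

provider = TransformersProvider()  # constructor arguments, if any, omitted here
descriptor = provider.get_text_embedder(
    model="sentence-transformers/all-MiniLM-L6-v2",
    dimensions=128,
)

# The descriptor now carries the requested dimensions instead of emitting a warning,
# so schema inference and the instantiated embedder agree on the output size.
print(descriptor.get_dimensions())  # expected: EmbeddingDimensions(size=128, dtype=float32)
embedder = descriptor.instantiate()  # loads the SentenceTransformer; embeddings are truncated to 128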
