Skip to content

Commit 8c6ee33

Browse files
authored
fix: Support custom models in openai text embedder (#5525)
## Changes Made Currently the openai text embedder asserts that the model must be an openai model, however this does not work if user passes in a custom BASE_URL that routes to an openai compatible server of a open source model like qwen. This PR elides the model check if user passes in BASE_URL, and also allows user to pass in custom `embedding_dimensions` ## Related Issues <!-- Link to related GitHub issues, e.g., "Closes #123" --> ## Checklist - [ ] Documented in API Docs (if applicable) - [ ] Documented in User Guide (if applicable) - [ ] If adding a new documentation page, doc is added to `docs/mkdocs.yml` navigation - [ ] Documentation builds and is formatted properly
1 parent c950538 commit 8c6ee33

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

daft/ai/openai/protocols/text_embedder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class OpenAITextEmbedderDescriptor(TextEmbedderDescriptor):
6363
model_options: Options
6464

6565
def __post_init__(self) -> None:
66-
if self.model_name not in _models:
66+
if self.provider_options.get("base_url") is None and self.model_name not in _models:
6767
supported_models = ", ".join(_models.keys())
6868
raise ValueError(
6969
f"Unsupported OpenAI embedding model '{self.model_name}', expected one of: {supported_models}"
@@ -79,10 +79,12 @@ def get_options(self) -> Options:
7979
return self.model_options
8080

8181
def get_dimensions(self) -> EmbeddingDimensions:
82+
if self.model_options.get("embedding_dimensions") is not None:
83+
return EmbeddingDimensions(size=self.model_options["embedding_dimensions"], dtype=DataType.float32())
8284
return _models[self.model_name].dimensions
8385

8486
def get_udf_options(self) -> UDFOptions:
85-
return get_http_udf_options()
87+
return UDFOptions(concurrency=None, num_gpus=None)
8688

8789
def is_async(self) -> bool:
8890
return True

0 commit comments

Comments
 (0)