src/embedders/classification/contextual.py (11 changes: 6 additions & 5 deletions)
@@ -217,15 +217,16 @@ class PrivatemodeAISentenceEmbedder(SentenceEmbedder):
     def __init__(
         self,
         batch_size: int = 128,
-        model_name: str = "intfloat/multilingual-e5-large-instruct",
+        model_name: str = "qwen3-embedding-4b",
+        hf_model_name: str = "boboliu/Qwen3-Embedding-4B-W4A16-G128",
     ):
         """
         Embeds documents using privatemode ai proxy via OpenAI classes.
         Note that the model and api key are currently hardcoded since they aren't configurable.
 
         Args:
             batch_size (int, optional): Defines the number of conversions after which the embedder yields. Defaults to 128.
-            model_name (str, optional): Name of the embedding model from Privatemode AI (e.g. intfloat/multilingual-e5-large-instruct). Defaults to "intfloat/multilingual-e5-large-instruct".
+            model_name (str, optional): Name of the embedding model from Privatemode AI (e.g. qwen3-embedding-4b). Defaults to "qwen3-embedding-4b".
 
         Raises:
             Exception: If you use Azure, you need to provide api_type, api_version and api_base.
@@ -238,8 +239,8 @@ def __init__(
api_key="dummy", # Set in proxy
base_url=PRIVATEMODE_AI_URL,
)
# for trimming the length of the text if > 512 tokens
self._auto_tokenizer = AutoTokenizer.from_pretrained(self.model_name)
# for trimming the length of the text if > 32000 tokens
self._auto_tokenizer = AutoTokenizer.from_pretrained(hf_model_name)

def _encode(
self, documents: List[Union[str, Doc]], fit_model: bool
@@ -278,7 +279,7 @@ def dump(self, project_id: str, embedding_id: str) -> None:
         export_file.parent.mkdir(parents=True, exist_ok=True)
         util.write_json(self.to_json(), export_file, indent=2)
 
-    def _trim_length(self, text: str, max_length: int = 512) -> str:
+    def _trim_length(self, text: str, max_length: int = 32000) -> str:
         tokens = self._auto_tokenizer(
             text,
             truncation=True,
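
For readers skimming the diff: below is a minimal, standalone sketch of how the trimming path is expected to behave after this change. The tokenizer checkpoint and the 32000-token cap come from the diff; the tail of _trim_length is cut off above, so passing max_length to the tokenizer and decoding the kept token ids back into text is an assumption, not the merged implementation.

# Hypothetical, self-contained sketch of the new trimming behavior.
# Assumption: the part of _trim_length elided in the diff forwards max_length
# to the tokenizer and decodes the surviving ids back into a string.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("boboliu/Qwen3-Embedding-4B-W4A16-G128")

def trim_length(text: str, max_length: int = 32000) -> str:
    # Truncate to at most `max_length` tokens before the text is embedded.
    tokens = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
    )
    # Decode the kept token ids back to text for the embedding request.
    return tokenizer.decode(tokens["input_ids"], skip_special_tokens=True)

Raising the cap from 512 to 32000 lines up with the much larger context window of the Qwen3-Embedding family, whereas the previous multilingual-e5 default tops out at 512 tokens. The tokenizer is now loaded from the separate hf_model_name checkpoint, presumably because the Privatemode AI identifier "qwen3-embedding-4b" is not itself a Hugging Face repo.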