Skip to content

Commit 4bc9153

Browse files
authored
[ENH]: fix Qwen EF hydration with custom prompts/tasks (#5808)
1 parent a7575cc commit 4bc9153

File tree

4 files changed

+51
-56
lines changed

4 files changed

+51
-56
lines changed

chromadb/test/utils/test_embedding_function_schemas.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,13 @@ def test_embedding_function_config_roundtrip(
5050
if ef_name == "chroma-cloud-qwen":
5151
from chromadb.utils.embedding_functions.chroma_cloud_qwen_embedding_function import (
5252
ChromaCloudQwenEmbeddingModel,
53-
ChromaCloudQwenEmbeddingTask,
5453
CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS,
5554
)
55+
5656
mock_ef.get_config.return_value = {
5757
"api_key_env_var": "CHROMA_API_KEY",
5858
"model": ChromaCloudQwenEmbeddingModel.QWEN3_EMBEDDING_0p6B.value,
59-
"task": ChromaCloudQwenEmbeddingTask.NL_TO_CODE.value,
59+
"task": "nl_to_code",
6060
"instructions": CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS,
6161
}
6262

chromadb/utils/embedding_functions/chroma_cloud_qwen_embedding_function.py

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,34 @@
55
from chromadb.utils.embedding_functions.schemas import validate_config_schema
66
from enum import Enum
77

8+
89
class ChromaCloudQwenEmbeddingModel(Enum):
910
QWEN3_EMBEDDING_0p6B = "Qwen/Qwen3-Embedding-0.6B"
1011

11-
class ChromaCloudQwenEmbeddingTask(Enum):
12-
NL_TO_CODE = "nl_to_code"
1312

1413
class ChromaCloudQwenEmbeddingTarget(Enum):
1514
DOCUMENTS = "documents"
1615
QUERY = "query"
1716

18-
ChromaCloudQwenEmbeddingInstructions = Dict[ChromaCloudQwenEmbeddingTask, Dict[ChromaCloudQwenEmbeddingTarget, str]]
17+
18+
ChromaCloudQwenEmbeddingInstructions = Dict[
19+
str, Dict[ChromaCloudQwenEmbeddingTarget, str]
20+
]
1921

2022
CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS: ChromaCloudQwenEmbeddingInstructions = {
21-
ChromaCloudQwenEmbeddingTask.NL_TO_CODE: {
23+
"nl_to_code": {
2224
ChromaCloudQwenEmbeddingTarget.DOCUMENTS: "",
2325
# Taken from https://github.com/QwenLM/Qwen3-Embedding/blob/main/evaluation/task_prompts.json
2426
ChromaCloudQwenEmbeddingTarget.QUERY: "Given a question about coding, retrieval code or passage that can solve user's question",
25-
}
27+
}
2628
}
2729

30+
2831
class ChromaCloudQwenEmbeddingFunction(EmbeddingFunction[Documents]):
2932
def __init__(
3033
self,
3134
model: ChromaCloudQwenEmbeddingModel,
32-
task: ChromaCloudQwenEmbeddingTask,
35+
task: str,
3336
instructions: ChromaCloudQwenEmbeddingInstructions = CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS,
3437
api_key_env_var: str = "CHROMA_API_KEY",
3538
):
@@ -38,11 +41,11 @@ def __init__(
3841
3942
Args:
4043
model (ChromaCloudQwenEmbeddingModel): The specific Qwen model to use for embeddings.
41-
task (ChromaCloudQwenEmbeddingTask): The task for which embeddings are being generated.
44+
task (str): The task for which embeddings are being generated.
4245
instructions (ChromaCloudQwenEmbeddingInstructions, optional): A dictionary containing
4346
custom instructions to use for the specified Qwen model. Defaults to CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS.
4447
api_key_env_var (str, optional): Environment variable name that contains your API key.
45-
Defaults to "CHROMA_API_KEY".
48+
Defaults to "CHROMA_API_KEY".
4649
"""
4750
try:
4851
import httpx
@@ -63,7 +66,10 @@ def __init__(
6366
self._api_url = "https://embed.trychroma.com"
6467
self._session = httpx.Client()
6568
self._session.headers.update(
66-
{"x-chroma-token": self.api_key, "x-chroma-embedding-model": self.model.value}
69+
{
70+
"x-chroma-token": self.api_key,
71+
"x-chroma-embedding-model": self.model.value,
72+
}
6773
)
6874

6975
def _parse_response(self, response: Any) -> Embeddings:
@@ -83,7 +89,6 @@ def _parse_response(self, response: Any) -> Embeddings:
8389

8490
return np.array(embeddings, dtype=np.float32)
8591

86-
8792
def __call__(self, input: Documents) -> Embeddings:
8893
"""
8994
Generate embeddings for the given documents.
@@ -98,7 +103,9 @@ def __call__(self, input: Documents) -> Embeddings:
98103
return []
99104

100105
payload: Dict[str, Union[str, Documents]] = {
101-
"instructions": self.instructions[self.task][ChromaCloudQwenEmbeddingTarget.DOCUMENTS],
106+
"instructions": self.instructions[self.task][
107+
ChromaCloudQwenEmbeddingTarget.DOCUMENTS
108+
],
102109
"texts": input,
103110
}
104111

@@ -114,7 +121,9 @@ def embed_query(self, input: Documents) -> Embeddings:
114121
return []
115122

116123
payload: Dict[str, Union[str, Documents]] = {
117-
"instructions": self.instructions[self.task][ChromaCloudQwenEmbeddingTarget.QUERY],
124+
"instructions": self.instructions[self.task][
125+
ChromaCloudQwenEmbeddingTarget.QUERY
126+
],
118127
"texts": input,
119128
}
120129

@@ -147,34 +156,30 @@ def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
147156
if instructions is not None:
148157
deserialized_instructions = {}
149158
for task_key, targets in instructions.items():
150-
# Convert string key to enum
151-
task_enum = ChromaCloudQwenEmbeddingTask(task_key)
152-
deserialized_instructions[task_enum] = {}
159+
deserialized_instructions[task_key] = {}
153160
for target_key, instruction in targets.items():
154161
# Convert string key to enum
155162
target_enum = ChromaCloudQwenEmbeddingTarget(target_key)
156-
deserialized_instructions[task_enum][target_enum] = instruction
163+
deserialized_instructions[task_key][target_enum] = instruction
164+
deserialized_instructions[task_key][target_enum] = instruction
157165

158166
return ChromaCloudQwenEmbeddingFunction(
159167
model=ChromaCloudQwenEmbeddingModel(model),
160-
task=ChromaCloudQwenEmbeddingTask(task),
168+
task=task,
161169
instructions=deserialized_instructions,
162170
api_key_env_var=api_key_env_var or "CHROMA_API_KEY",
163171
)
164172

165173
def get_config(self) -> Dict[str, Any]:
166174
# Serialize instructions for JSON compatibility: task keys are already strings, so only the target enum keys are converted to their string values
167175
serialized_instructions = {
168-
task.value: {
169-
target.value: instruction
170-
for target, instruction in targets.items()
171-
}
176+
task: {target.value: instruction for target, instruction in targets.items()}
172177
for task, targets in self.instructions.items()
173178
}
174179
return {
175180
"api_key_env_var": self.api_key_env_var,
176181
"model": self.model.value,
177-
"task": self.task.value,
182+
"task": self.task,
178183
"instructions": serialized_instructions,
179184
}
180185

@@ -192,7 +197,7 @@ def validate_config_update(
192197
elif "instructions" in new_config:
193198
raise ValueError(
194199
"The instructions cannot be changed after the embedding function has been initialized."
195-
)
200+
)
196201

197202
@staticmethod
198203
def validate_config(config: Dict[str, Any]) -> None:
@@ -205,4 +210,4 @@ def validate_config(config: Dict[str, Any]) -> None:
205210
Raises:
206211
ValidationError: If the configuration does not match the schema
207212
"""
208-
validate_config_schema(config, "chroma-cloud-qwen")
213+
validate_config_schema(config, "chroma-cloud-qwen")

clients/new-js/packages/ai-embeddings/chroma-cloud-qwen/src/index.test.ts

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ import {
22
CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS,
33
ChromaCloudQwenEmbeddingFunction,
44
ChromaCloudQwenEmbeddingModel,
5-
ChromaCloudQwenEmbeddingTask,
65
} from "./index";
76
import { beforeEach, describe, expect, it, jest } from "@jest/globals";
87

@@ -13,12 +12,11 @@ describe("ChromaCloudQwenEmbeddingFunction", () => {
1312

1413
const defaultParametersTest = "should initialize with default parameters";
1514
if (!process.env.CHROMA_API_KEY) {
16-
it.skip(defaultParametersTest, () => {});
15+
it.skip(defaultParametersTest, () => { });
1716
} else {
1817
it(defaultParametersTest, () => {
1918
const embedder = new ChromaCloudQwenEmbeddingFunction({
2019
model: ChromaCloudQwenEmbeddingModel.QWEN3_EMBEDDING_0p6B,
21-
task: ChromaCloudQwenEmbeddingTask.NL_TO_CODE,
2220
});
2321
expect(embedder.name).toBe("chroma-cloud-qwen");
2422

@@ -37,7 +35,6 @@ describe("ChromaCloudQwenEmbeddingFunction", () => {
3735
expect(() => {
3836
new ChromaCloudQwenEmbeddingFunction({
3937
model: ChromaCloudQwenEmbeddingModel.QWEN3_EMBEDDING_0p6B,
40-
task: ChromaCloudQwenEmbeddingTask.NL_TO_CODE,
4138
});
4239
}).toThrow("Chroma Embedding API key is required");
4340
} finally {
@@ -53,7 +50,6 @@ describe("ChromaCloudQwenEmbeddingFunction", () => {
5350
try {
5451
const embedder = new ChromaCloudQwenEmbeddingFunction({
5552
model: ChromaCloudQwenEmbeddingModel.QWEN3_EMBEDDING_0p6B,
56-
task: ChromaCloudQwenEmbeddingTask.NL_TO_CODE,
5753
apiKeyEnvVar: "CUSTOM_CHROMA_API_KEY",
5854
});
5955

@@ -67,14 +63,14 @@ describe("ChromaCloudQwenEmbeddingFunction", () => {
6763

6864
const buildFromConfigTest = "should build from config";
6965
if (!process.env.CHROMA_API_KEY) {
70-
it.skip(buildFromConfigTest, () => {});
66+
it.skip(buildFromConfigTest, () => { });
7167
} else {
7268
it(buildFromConfigTest, () => {
7369
const config = {
7470
api_key_env_var: "CHROMA_API_KEY",
7571
model: ChromaCloudQwenEmbeddingModel.QWEN3_EMBEDDING_0p6B,
76-
task: ChromaCloudQwenEmbeddingTask.NL_TO_CODE,
7772
instructions: CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS,
73+
task: "nl_to_code",
7874
};
7975

8076
const embedder = ChromaCloudQwenEmbeddingFunction.buildFromConfig(config);
@@ -84,12 +80,11 @@ describe("ChromaCloudQwenEmbeddingFunction", () => {
8480

8581
const generateEmbeddingsTest = "should generate embeddings";
8682
if (!process.env.CHROMA_API_KEY) {
87-
it.skip(generateEmbeddingsTest, () => {});
83+
it.skip(generateEmbeddingsTest, () => { });
8884
} else {
8985
it(generateEmbeddingsTest, async () => {
9086
const embedder = new ChromaCloudQwenEmbeddingFunction({
9187
model: ChromaCloudQwenEmbeddingModel.QWEN3_EMBEDDING_0p6B,
92-
task: ChromaCloudQwenEmbeddingTask.NL_TO_CODE,
9388
});
9489
const texts = ["Hello world", "Test text"];
9590
const embeddings = await embedder.generate(texts);

clients/new-js/packages/ai-embeddings/chroma-cloud-qwen/src/index.ts

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@ import {
1111

1212
const NAME = "chroma-cloud-qwen";
1313

14+
1415
export interface ChromaCloudQwenConfig {
1516
model: ChromaCloudQwenEmbeddingModel;
16-
task: ChromaCloudQwenEmbeddingTask;
17+
task: string;
1718
instructions: ChromaCloudQwenEmbeddingInstructions;
1819
api_key_env_var: string;
1920
}
@@ -22,33 +23,29 @@ export enum ChromaCloudQwenEmbeddingModel {
2223
QWEN3_EMBEDDING_0p6B = "Qwen/Qwen3-Embedding-0.6B",
2324
}
2425

25-
export enum ChromaCloudQwenEmbeddingTask {
26-
NL_TO_CODE = "nl_to_code", // Queries are in natural language, documents are code
27-
}
28-
2926
export enum ChromaCloudQwenEmbeddingTarget {
3027
DOCUMENTS = "documents",
3128
QUERY = "query",
3229
}
3330

3431
export type ChromaCloudQwenEmbeddingInstructions = Record<
35-
ChromaCloudQwenEmbeddingTask,
32+
string,
3633
Record<ChromaCloudQwenEmbeddingTarget, string>
3734
>;
3835

3936
export const CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS: ChromaCloudQwenEmbeddingInstructions =
40-
{
41-
[ChromaCloudQwenEmbeddingTask.NL_TO_CODE]: {
42-
[ChromaCloudQwenEmbeddingTarget.DOCUMENTS]: "",
43-
[ChromaCloudQwenEmbeddingTarget.QUERY]:
44-
// Taken from https://github.com/QwenLM/Qwen3-Embedding/blob/main/evaluation/task_prompts.json
45-
"Given a question about coding, retrieval code or passage that can solve user's question",
46-
},
47-
};
37+
{
38+
"nl_to_code": {
39+
[ChromaCloudQwenEmbeddingTarget.DOCUMENTS]: "",
40+
[ChromaCloudQwenEmbeddingTarget.QUERY]:
41+
// Taken from https://github.com/QwenLM/Qwen3-Embedding/blob/main/evaluation/task_prompts.json
42+
"Given a question about coding, retrieval code or passage that can solve user's question",
43+
},
44+
};
4845

4946
export interface ChromaCloudQwenArgs {
5047
model: ChromaCloudQwenEmbeddingModel;
51-
task: ChromaCloudQwenEmbeddingTask;
48+
task?: string;
5249
instructions?: ChromaCloudQwenEmbeddingInstructions;
5350
apiKeyEnvVar?: string;
5451
}
@@ -70,13 +67,13 @@ export class ChromaCloudQwenEmbeddingFunction implements EmbeddingFunction {
7067
private readonly model: ChromaCloudQwenEmbeddingModel;
7168
private readonly url: string;
7269
private readonly headers: { [key: string]: string };
73-
private readonly task: ChromaCloudQwenEmbeddingTask;
70+
private readonly task: string;
7471
private readonly instructions: ChromaCloudQwenEmbeddingInstructions;
7572

7673
constructor(args: ChromaCloudQwenArgs) {
7774
const {
7875
model,
79-
task,
76+
task = "nl_to_code",
8077
instructions = CHROMA_CLOUD_QWEN_DEFAULT_INSTRUCTIONS,
8178
apiKeyEnvVar = "CHROMA_API_KEY",
8279
} = args;
@@ -185,16 +182,14 @@ export class ChromaCloudQwenEmbeddingFunction implements EmbeddingFunction {
185182
if (config.instructions) {
186183
deserializedInstructions = {} as ChromaCloudQwenEmbeddingInstructions;
187184
for (const [taskKey, targets] of Object.entries(config.instructions)) {
188-
// taskKey is the enum value string like "nl_to_code"
189-
const taskEnum = taskKey as ChromaCloudQwenEmbeddingTask;
190-
deserializedInstructions[taskEnum] = {} as Record<
185+
deserializedInstructions[taskKey] = {} as Record<
191186
ChromaCloudQwenEmbeddingTarget,
192187
string
193188
>;
194189
for (const [targetKey, instruction] of Object.entries(targets)) {
195190
// targetKey is the enum value string like "documents" or "query"
196191
const targetEnum = targetKey as ChromaCloudQwenEmbeddingTarget;
197-
deserializedInstructions[taskEnum][targetEnum] = instruction;
192+
deserializedInstructions[taskKey][targetEnum] = instruction;
198193
}
199194
}
200195
} else {

0 commit comments

Comments
 (0)