Skip to content

Commit 96dabb5

Browse files
committed
fix: make embedding model ID matching case-insensitive
- Updated getModelDimension, getModelScoreThreshold, and getModelQueryPrefix to perform case-insensitive matching - Fixes issue where custom providers using different casing (e.g., Qwen/Qwen3-Embedding-8B) were not recognized - Added comprehensive tests for case-insensitive model matching Fixes #9026
1 parent 8e4b145 commit 96dabb5

File tree

2 files changed

+184
-6
lines changed

2 files changed

+184
-6
lines changed
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
import { describe, it, expect } from "vitest"
2+
import {
3+
getModelDimension,
4+
getModelScoreThreshold,
5+
getModelQueryPrefix,
6+
getDefaultModelId,
7+
EMBEDDING_MODEL_PROFILES,
8+
} from "../embeddingModels"
9+
10+
describe("embeddingModels", () => {
11+
describe("getModelDimension", () => {
12+
it("should return the correct dimension for a valid model", () => {
13+
expect(getModelDimension("openai", "text-embedding-3-small")).toBe(1536)
14+
expect(getModelDimension("openai", "text-embedding-3-large")).toBe(3072)
15+
expect(getModelDimension("openrouter", "qwen/qwen3-embedding-8b")).toBe(4096)
16+
})
17+
18+
it("should be case-insensitive for model IDs", () => {
19+
// Test with different case variations
20+
expect(getModelDimension("openai", "TEXT-EMBEDDING-3-SMALL")).toBe(1536)
21+
expect(getModelDimension("openai", "Text-Embedding-3-Large")).toBe(3072)
22+
expect(getModelDimension("openrouter", "Qwen/Qwen3-Embedding-8B")).toBe(4096)
23+
expect(getModelDimension("openrouter", "QWEN/QWEN3-EMBEDDING-8B")).toBe(4096)
24+
25+
// Test with mixed case for other providers
26+
expect(getModelDimension("gemini", "TEXT-EMBEDDING-004")).toBe(768)
27+
expect(getModelDimension("mistral", "CODESTRAL-EMBED-2505")).toBe(1536)
28+
})
29+
30+
it("should return undefined for non-existent model", () => {
31+
expect(getModelDimension("openai", "non-existent-model")).toBeUndefined()
32+
})
33+
34+
it("should return undefined for non-existent provider", () => {
35+
// @ts-expect-error Testing with invalid provider
36+
expect(getModelDimension("non-existent-provider", "text-embedding-3-small")).toBeUndefined()
37+
})
38+
39+
it("should handle lowercase model IDs that exist in profiles", () => {
40+
expect(getModelDimension("openai", "text-embedding-ada-002")).toBe(1536)
41+
expect(getModelDimension("ollama", "nomic-embed-text")).toBe(768)
42+
})
43+
})
44+
45+
describe("getModelScoreThreshold", () => {
46+
it("should return the correct score threshold for a valid model", () => {
47+
expect(getModelScoreThreshold("openai", "text-embedding-3-small")).toBe(0.4)
48+
expect(getModelScoreThreshold("ollama", "nomic-embed-code")).toBe(0.15)
49+
expect(getModelScoreThreshold("openrouter", "qwen/qwen3-embedding-8b")).toBe(0.4)
50+
})
51+
52+
it("should be case-insensitive for model IDs", () => {
53+
// Test with different case variations
54+
expect(getModelScoreThreshold("openai", "TEXT-EMBEDDING-3-SMALL")).toBe(0.4)
55+
expect(getModelScoreThreshold("ollama", "NOMIC-EMBED-CODE")).toBe(0.15)
56+
expect(getModelScoreThreshold("openrouter", "Qwen/Qwen3-Embedding-8B")).toBe(0.4)
57+
58+
// Test models without score thresholds
59+
expect(getModelScoreThreshold("gemini", "TEXT-EMBEDDING-004")).toBeUndefined()
60+
})
61+
62+
it("should return undefined for model without score threshold", () => {
63+
expect(getModelScoreThreshold("gemini", "text-embedding-004")).toBeUndefined()
64+
})
65+
66+
it("should return undefined for non-existent model", () => {
67+
expect(getModelScoreThreshold("openai", "non-existent-model")).toBeUndefined()
68+
})
69+
70+
it("should return undefined for non-existent provider", () => {
71+
// @ts-expect-error Testing with invalid provider
72+
expect(getModelScoreThreshold("non-existent-provider", "text-embedding-3-small")).toBeUndefined()
73+
})
74+
})
75+
76+
describe("getModelQueryPrefix", () => {
77+
it("should return the correct query prefix for a model that has one", () => {
78+
expect(getModelQueryPrefix("ollama", "nomic-embed-code")).toBe(
79+
"Represent this query for searching relevant code: ",
80+
)
81+
})
82+
83+
it("should be case-insensitive for model IDs", () => {
84+
// Test with different case variations
85+
expect(getModelQueryPrefix("ollama", "NOMIC-EMBED-CODE")).toBe(
86+
"Represent this query for searching relevant code: ",
87+
)
88+
expect(getModelQueryPrefix("ollama", "Nomic-Embed-Code")).toBe(
89+
"Represent this query for searching relevant code: ",
90+
)
91+
expect(getModelQueryPrefix("openai-compatible", "NOMIC-EMBED-CODE")).toBe(
92+
"Represent this query for searching relevant code: ",
93+
)
94+
})
95+
96+
it("should return undefined for model without query prefix", () => {
97+
expect(getModelQueryPrefix("openai", "text-embedding-3-small")).toBeUndefined()
98+
expect(getModelQueryPrefix("gemini", "text-embedding-004")).toBeUndefined()
99+
})
100+
101+
it("should return undefined for non-existent model", () => {
102+
expect(getModelQueryPrefix("ollama", "non-existent-model")).toBeUndefined()
103+
})
104+
105+
it("should return undefined for non-existent provider", () => {
106+
// @ts-expect-error Testing with invalid provider
107+
expect(getModelQueryPrefix("non-existent-provider", "nomic-embed-code")).toBeUndefined()
108+
})
109+
})
110+
111+
describe("getDefaultModelId", () => {
112+
it("should return the correct default model for each provider", () => {
113+
expect(getDefaultModelId("openai")).toBe("text-embedding-3-small")
114+
expect(getDefaultModelId("openai-compatible")).toBe("text-embedding-3-small")
115+
expect(getDefaultModelId("gemini")).toBe("gemini-embedding-001")
116+
expect(getDefaultModelId("mistral")).toBe("codestral-embed-2505")
117+
expect(getDefaultModelId("vercel-ai-gateway")).toBe("openai/text-embedding-3-large")
118+
expect(getDefaultModelId("openrouter")).toBe("openai/text-embedding-3-large")
119+
})
120+
121+
it("should return a default for Ollama", () => {
122+
const defaultModel = getDefaultModelId("ollama")
123+
expect(defaultModel).toBeDefined()
124+
expect(EMBEDDING_MODEL_PROFILES.ollama?.[defaultModel]).toBeDefined()
125+
})
126+
127+
it("should return fallback for unknown provider", () => {
128+
// @ts-expect-error Testing with invalid provider
129+
expect(getDefaultModelId("unknown-provider")).toBe("text-embedding-3-small")
130+
})
131+
})
132+
133+
describe("Qwen model specific tests", () => {
134+
it("should handle Qwen model with original casing", () => {
135+
expect(getModelDimension("openrouter", "qwen/qwen3-embedding-8b")).toBe(4096)
136+
expect(getModelScoreThreshold("openrouter", "qwen/qwen3-embedding-8b")).toBe(0.4)
137+
})
138+
139+
it("should handle Qwen model with user's casing from issue", () => {
140+
// This is the exact casing from the user's issue
141+
expect(getModelDimension("openrouter", "Qwen/Qwen3-Embedding-8B")).toBe(4096)
142+
expect(getModelScoreThreshold("openrouter", "Qwen/Qwen3-Embedding-8B")).toBe(0.4)
143+
})
144+
145+
it("should handle Qwen model with all uppercase", () => {
146+
expect(getModelDimension("openrouter", "QWEN/QWEN3-EMBEDDING-8B")).toBe(4096)
147+
expect(getModelScoreThreshold("openrouter", "QWEN/QWEN3-EMBEDDING-8B")).toBe(0.4)
148+
})
149+
150+
it("should handle Qwen model with random casing", () => {
151+
expect(getModelDimension("openrouter", "qWeN/QwEn3-EmBeDdInG-8b")).toBe(4096)
152+
expect(getModelScoreThreshold("openrouter", "qWeN/QwEn3-EmBeDdInG-8b")).toBe(0.4)
153+
})
154+
})
155+
})

src/shared/embeddingModels.ts

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,19 @@ export function getModelDimension(provider: EmbedderProvider, modelId: string):
105105
return undefined
106106
}
107107

108-
const modelProfile = providerProfiles[modelId]
108+
// Convert modelId to lowercase for case-insensitive comparison
109+
const lowerModelId = modelId.toLowerCase()
110+
111+
// Find the model profile with case-insensitive matching
112+
const modelProfile = Object.keys(providerProfiles).find((key) => key.toLowerCase() === lowerModelId)
113+
109114
if (!modelProfile) {
110115
// Don't warn here, as it might be a custom model ID not in our profiles
111116
// console.warn(`Model not found for provider ${provider}: ${modelId}`)
112117
return undefined // Or potentially return a default/fallback dimension?
113118
}
114119

115-
return modelProfile.dimension
120+
return providerProfiles[modelProfile].dimension
116121
}
117122

118123
/**
@@ -127,8 +132,17 @@ export function getModelScoreThreshold(provider: EmbedderProvider, modelId: stri
127132
return undefined
128133
}
129134

130-
const modelProfile = providerProfiles[modelId]
131-
return modelProfile?.scoreThreshold
135+
// Convert modelId to lowercase for case-insensitive comparison
136+
const lowerModelId = modelId.toLowerCase()
137+
138+
// Find the model profile with case-insensitive matching
139+
const modelProfileKey = Object.keys(providerProfiles).find((key) => key.toLowerCase() === lowerModelId)
140+
141+
if (!modelProfileKey) {
142+
return undefined
143+
}
144+
145+
return providerProfiles[modelProfileKey]?.scoreThreshold
132146
}
133147

134148
/**
@@ -143,8 +157,17 @@ export function getModelQueryPrefix(provider: EmbedderProvider, modelId: string)
143157
return undefined
144158
}
145159

146-
const modelProfile = providerProfiles[modelId]
147-
return modelProfile?.queryPrefix
160+
// Convert modelId to lowercase for case-insensitive comparison
161+
const lowerModelId = modelId.toLowerCase()
162+
163+
// Find the model profile with case-insensitive matching
164+
const modelProfileKey = Object.keys(providerProfiles).find((key) => key.toLowerCase() === lowerModelId)
165+
166+
if (!modelProfileKey) {
167+
return undefined
168+
}
169+
170+
return providerProfiles[modelProfileKey]?.queryPrefix
148171
}
149172

150173
/**

0 commit comments

Comments
 (0)