Skip to content

Commit c739f34

Browse files
authored
feat: support output dimensions for embedding models (#490) (#522)
Add an optional outputDimension field for embedding models to request specific dimensions from providers that support Matryoshka Representation Learning (MRL).

Supported providers:
- OpenAI text-embedding-3-* (via the 'dimensions' parameter)
- Google gemini-embedding-001 (via 'config.outputDimensionality')

Changes:
- Add an outputDimension field to the EmbeddingModel type
- Update all provider getEmbedding methods to accept a dimensions option
- Add an Output Dimensions input field to the Add Embedding Model modal
- Validate that the returned dimension matches the requested outputDimension
1 parent 8d1ce6e commit c739f34

19 files changed

+127
-17
lines changed

src/components/settings/modals/AddEmbeddingModelModal.tsx

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ function AddEmbeddingModelModalComponent({
4343
providerType: DEFAULT_PROVIDERS[0].type,
4444
id: '',
4545
model: '',
46+
outputDimension: undefined,
4647
})
48+
const [outputDimensionInput, setOutputDimensionInput] = useState('')
4749

4850
const handleSubmit = async () => {
4951
try {
@@ -69,6 +71,7 @@ function AddEmbeddingModelModalComponent({
6971
const embeddingResult = await providerClient.getEmbedding(
7072
formData.model,
7173
'test',
74+
{ dimensions: formData.outputDimension },
7275
)
7376

7477
if (!Array.isArray(embeddingResult) || embeddingResult.length === 0) {
@@ -77,6 +80,18 @@ function AddEmbeddingModelModalComponent({
7780

7881
const dimension = embeddingResult.length
7982

83+
// Validate that the model respected the requested output dimension
84+
if (
85+
formData.outputDimension !== undefined &&
86+
dimension !== formData.outputDimension
87+
) {
88+
throw new Error(
89+
`Requested output dimension ${formData.outputDimension}, but the model returned ${dimension} dimensions. ` +
90+
`This model may not support custom output dimensions (Matryoshka Representation Learning). ` +
91+
`Leave the "Output Dimensions" field empty to use the model's default dimension.`,
92+
)
93+
}
94+
8095
if (!supportedDimensionsForIndex.includes(dimension)) {
8196
const confirmed = await new Promise<boolean>((resolve) => {
8297
new ConfirmModal(plugin.app, {
@@ -174,6 +189,24 @@ function AddEmbeddingModelModalComponent({
174189
/>
175190
</ObsidianSetting>
176191

192+
<ObsidianSetting
193+
name="Output Dimensions"
194+
desc="Optional. Request a specific output dimension from models that support Matryoshka Representation Learning (MRL), such as OpenAI's text-embedding-3-* or Google's gemini-embedding-001. Leave empty to use the model's default dimension."
195+
>
196+
<ObsidianTextInput
197+
value={outputDimensionInput}
198+
placeholder="e.g., 768"
199+
onChange={(value: string) => {
200+
setOutputDimensionInput(value)
201+
const parsed = parseInt(value, 10)
202+
setFormData((prev) => ({
203+
...prev,
204+
outputDimension: isNaN(parsed) ? undefined : parsed,
205+
}))
206+
}}
207+
/>
208+
</ObsidianSetting>
209+
177210
<ObsidianSetting>
178211
<ObsidianButton text="Add" onClick={handleSubmit} cta />
179212
<ObsidianButton text="Cancel" onClick={onClose} />

src/core/llm/anthropic.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,11 @@ https://github.com/glowingjade/obsidian-smart-composer/issues/286`,
603603
throw new Error(`Unsupported tool choice: ${JSON.stringify(toolChoice)}`)
604604
}
605605

606-
async getEmbedding(_model: string, _text: string): Promise<number[]> {
606+
async getEmbedding(
607+
_model: string,
608+
_text: string,
609+
_options?: { dimensions?: number },
610+
): Promise<number[]> {
607611
throw new Error(
608612
`Provider ${this.provider.id} does not support embeddings. Please use a different provider.`,
609613
)

src/core/llm/anthropicClaudeCodeProvider.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,11 @@ export class AnthropicClaudeCodeProvider extends BaseLLMProvider<
7777
)
7878
}
7979

80-
async getEmbedding(_model: string, _text: string): Promise<number[]> {
80+
async getEmbedding(
81+
_model: string,
82+
_text: string,
83+
_options?: { dimensions?: number },
84+
): Promise<number[]> {
8185
throw new Error(
8286
`Provider ${this.provider.id} does not support embeddings. Please use a different provider.`,
8387
)

src/core/llm/azureOpenaiProvider.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,11 @@ export class AzureOpenAIProvider extends BaseLLMProvider<
5757
return this.adapter.streamResponse(this.client, request, options)
5858
}
5959

60-
async getEmbedding(_model: string, _text: string): Promise<number[]> {
60+
async getEmbedding(
61+
_model: string,
62+
_text: string,
63+
_options?: { dimensions?: number },
64+
): Promise<number[]> {
6165
throw new Error(
6266
`Provider ${this.provider.id} does not support embeddings. Please use a different provider.`,
6367
)

src/core/llm/base.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,9 @@ export abstract class BaseLLMProvider<P extends LLMProvider> {
2929
options?: LLMOptions,
3030
): Promise<AsyncIterable<LLMResponseStreaming>>
3131

32-
abstract getEmbedding(model: string, text: string): Promise<number[]>
32+
abstract getEmbedding(
33+
model: string,
34+
text: string,
35+
options?: { dimensions?: number },
36+
): Promise<number[]>
3337
}

src/core/llm/deepseekStudioProvider.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,11 @@ export class DeepSeekStudioProvider extends BaseLLMProvider<
8080
return this.adapter.streamResponse(this.client, formattedRequest, options)
8181
}
8282

83-
async getEmbedding(_model: string, _text: string): Promise<number[]> {
83+
async getEmbedding(
84+
_model: string,
85+
_text: string,
86+
_options?: { dimensions?: number },
87+
): Promise<number[]> {
8488
throw new Error(
8589
`Provider ${this.provider.id} does not support embeddings. Please use a different provider.`,
8690
)

src/core/llm/gemini.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,11 @@ export class GeminiProvider extends BaseLLMProvider<
543543
return config
544544
}
545545

546-
async getEmbedding(model: string, text: string): Promise<number[]> {
546+
async getEmbedding(
547+
model: string,
548+
text: string,
549+
options?: { dimensions?: number },
550+
): Promise<number[]> {
547551
if (!this.apiKey) {
548552
throw new LLMAPIKeyNotSetException(
549553
`Provider ${this.provider.id} API key is missing. Please set it in settings menu.`,
@@ -554,6 +558,9 @@ export class GeminiProvider extends BaseLLMProvider<
554558
const response = await this.client.models.embedContent({
555559
model: model,
556560
contents: text,
561+
...(options?.dimensions && {
562+
config: { outputDimensionality: options.dimensions },
563+
}),
557564
})
558565
return response.embeddings?.[0]?.values ?? []
559566
} catch (error) {

src/core/llm/geminiPlanProvider.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,11 @@ export class GeminiPlanProvider extends BaseLLMProvider<
7979
)
8080
}
8181

82-
async getEmbedding(_model: string, _text: string): Promise<number[]> {
82+
async getEmbedding(
83+
_model: string,
84+
_text: string,
85+
_options?: { dimensions?: number },
86+
): Promise<number[]> {
8387
throw new Error(
8488
`Provider ${this.provider.id} does not support embeddings. Please use a different provider.`,
8589
)

src/core/llm/lmStudioProvider.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,16 @@ export class LmStudioProvider extends BaseLLMProvider<
5555
return this.adapter.streamResponse(this.client, request, options)
5656
}
5757

58-
async getEmbedding(model: string, text: string): Promise<number[]> {
58+
async getEmbedding(
59+
model: string,
60+
text: string,
61+
options?: { dimensions?: number },
62+
): Promise<number[]> {
5963
const embedding = await this.client.embeddings.create({
6064
model: model,
6165
input: text,
6266
encoding_format: 'float',
67+
...(options?.dimensions && { dimensions: options.dimensions }),
6368
})
6469
return embedding.data[0].embedding
6570
}

src/core/llm/mistralProvider.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,11 @@ export class MistralProvider extends BaseLLMProvider<
5656
return this.adapter.streamResponse(this.client, request, options)
5757
}
5858

59-
async getEmbedding(_model: string, _text: string): Promise<number[]> {
59+
async getEmbedding(
60+
_model: string,
61+
_text: string,
62+
_options?: { dimensions?: number },
63+
): Promise<number[]> {
6064
throw new Error(
6165
`Provider ${this.provider.id} does not support embeddings. Please use a different provider.`,
6266
)

0 commit comments

Comments
 (0)