Skip to content

Commit 49e775a

Browse files
committed
fix: support full endpoint URLs in OpenAI Compatible provider (#5212)
1 parent 8bb7ef0 commit 49e775a

File tree

2 files changed

+229
-8
lines changed

2 files changed

+229
-8
lines changed

src/services/code-index/embedders/__tests__/openai-compatible.spec.ts

Lines changed: 157 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,9 @@ import { MAX_ITEM_TOKENS, INITIAL_RETRY_DELAY_MS } from "../../constants"
66
// Mock the OpenAI SDK
77
vitest.mock("openai")
88

9+
// Mock global fetch
10+
global.fetch = vitest.fn()
11+
912
// Mock i18n
1013
vitest.mock("../../../../i18n", () => ({
1114
t: (key: string, params?: Record<string, any>) => {
@@ -613,5 +616,159 @@ describe("OpenAICompatibleEmbedder", () => {
613616
expect(returnedArray).toEqual([0.25, 0.5, 0.75, 1.0])
614617
})
615618
})
619+
620+
/**
 * Azure OpenAI compatibility.
 *
 * Verifies that a full endpoint URL (Azure-style, including a query string) is sent through a
 * direct `fetch` call, while a plain base URL keeps going through the mocked OpenAI SDK client.
 */
describe("Azure OpenAI compatibility", () => {
	const azureUrl =
		"https://myresource.openai.azure.com/openai/deployments/mymodel/embeddings?api-version=2024-02-01"
	const baseUrl = "https://api.openai.com/v1"

	// Builds a minimal stand-in for a fetch Response exposing `ok`, `status`, `json()` and `text()`.
	// `text()` resolves to "Error message" for every status other than 200.
	const createMockResponse = (data: any, status = 200, ok = true) => ({
		ok,
		status,
		json: vitest.fn().mockResolvedValue(data),
		text: vitest.fn().mockResolvedValue(status === 200 ? "" : "Error message"),
	})

	// Encodes numbers as the base64 form of their Float32Array bytes,
	// i.e. the payload shape requested with `encoding_format: "base64"`.
	const createBase64Embedding = (values: number[]) => {
		const embedding = new Float32Array(values)
		return Buffer.from(embedding.buffer).toString("base64")
	}

	// Compares embeddings with a tolerance, because values round-trip through 32-bit floats.
	const expectEmbeddingValues = (actual: number[], expected: number[]) => {
		expect(actual).toHaveLength(expected.length)
		expected.forEach((val, i) => expect(actual[i]).toBeCloseTo(val, 5))
	}

	beforeEach(() => {
		vitest.clearAllMocks()
		// mockReset also drops queued mockResolvedValueOnce results left over from a previous test
		;(global.fetch as MockedFunction<typeof fetch>).mockReset()
	})

	describe("URL detection", () => {
		it.each([
			[
				"https://myresource.openai.azure.com/openai/deployments/mymodel/embeddings?api-version=2024-02-01",
				true,
			],
			["https://myresource.openai.azure.com/openai/deployments/text-embedding-ada-002/embeddings", true],
			["https://api.openai.com/v1", false],
			["https://api.example.com", false],
			["http://localhost:8080", false],
		])("should detect URL type correctly: %s -> %s", (url, expected) => {
			const embedder = new OpenAICompatibleEmbedder(url, testApiKey, testModelId)
			// isFullEndpointUrl is private, hence the cast
			const isFullUrl = (embedder as any).isFullEndpointUrl(url)
			expect(isFullUrl).toBe(expected)
		})
	})

	describe("direct HTTP requests", () => {
		it("should use direct fetch for Azure URLs and SDK for base URLs", async () => {
			const testTexts = ["Test text"]
			const base64String = createBase64Embedding([0.1, 0.2, 0.3])

			// Azure URL: the request must go through fetch with both auth headers, never the SDK
			const azureEmbedder = new OpenAICompatibleEmbedder(azureUrl, testApiKey, testModelId)
			const mockFetchResponse = createMockResponse({
				data: [{ embedding: base64String }],
				usage: { prompt_tokens: 10, total_tokens: 15 },
			})
			;(global.fetch as MockedFunction<typeof fetch>).mockResolvedValue(mockFetchResponse as any)

			const azureResult = await azureEmbedder.createEmbeddings(testTexts)
			expect(global.fetch).toHaveBeenCalledWith(
				azureUrl,
				expect.objectContaining({
					method: "POST",
					headers: expect.objectContaining({
						"api-key": testApiKey,
						Authorization: `Bearer ${testApiKey}`,
					}),
				}),
			)
			expect(mockEmbeddingsCreate).not.toHaveBeenCalled()
			expectEmbeddingValues(azureResult.embeddings[0], [0.1, 0.2, 0.3])

			// Base URL: clear recorded calls, then expect the SDK path and no fetch call
			vitest.clearAllMocks()
			const baseEmbedder = new OpenAICompatibleEmbedder(baseUrl, testApiKey, testModelId)
			mockEmbeddingsCreate.mockResolvedValue({
				data: [{ embedding: [0.4, 0.5, 0.6] }],
				usage: { prompt_tokens: 10, total_tokens: 15 },
			})

			const baseResult = await baseEmbedder.createEmbeddings(testTexts)
			expect(mockEmbeddingsCreate).toHaveBeenCalled()
			expect(global.fetch).not.toHaveBeenCalled()
			expect(baseResult.embeddings[0]).toEqual([0.4, 0.5, 0.6])
		})

		it.each([
			[401, "Authentication failed. Please check your API key."],
			[500, "Failed to create embeddings after 3 attempts"],
		])("should handle HTTP errors: %d", async (status, expectedMessage) => {
			const embedder = new OpenAICompatibleEmbedder(azureUrl, testApiKey, testModelId)
			const mockResponse = createMockResponse({}, status, false)
			;(global.fetch as MockedFunction<typeof fetch>).mockResolvedValue(mockResponse as any)

			await expect(embedder.createEmbeddings(["test"])).rejects.toThrow(expectedMessage)
		})

		it("should handle rate limiting with retries", async () => {
			vitest.useFakeTimers()
			try {
				const embedder = new OpenAICompatibleEmbedder(azureUrl, testApiKey, testModelId)
				const base64String = createBase64Embedding([0.1, 0.2, 0.3])

				// Two 429 responses followed by a success
				;(global.fetch as MockedFunction<typeof fetch>)
					.mockResolvedValueOnce(createMockResponse({}, 429, false) as any)
					.mockResolvedValueOnce(createMockResponse({}, 429, false) as any)
					.mockResolvedValueOnce(
						createMockResponse({
							data: [{ embedding: base64String }],
							usage: { prompt_tokens: 10, total_tokens: 15 },
						}) as any,
					)

				const resultPromise = embedder.createEmbeddings(["test"])
				// Presumably long enough to cover both backoff delays — confirm against the retry logic
				await vitest.advanceTimersByTimeAsync(INITIAL_RETRY_DELAY_MS * 3)
				const result = await resultPromise

				expect(global.fetch).toHaveBeenCalledTimes(3)
				// NOTE(review): assumes console.warn is already mocked by the enclosing suite's setup
				expect(console.warn).toHaveBeenCalledWith(expect.stringContaining("Rate limit hit"))
				expectEmbeddingValues(result.embeddings[0], [0.1, 0.2, 0.3])
			} finally {
				// Restore real timers even when an assertion above fails, so fake timers
				// cannot leak into the tests that run afterwards
				vitest.useRealTimers()
			}
		})

		it("should handle multiple embeddings and network errors", async () => {
			const embedder = new OpenAICompatibleEmbedder(azureUrl, testApiKey, testModelId)

			// Multiple embeddings in one response keep their order
			const base64_1 = createBase64Embedding([0.25, 0.5])
			const base64_2 = createBase64Embedding([0.75, 1.0])
			const mockResponse = createMockResponse({
				data: [{ embedding: base64_1 }, { embedding: base64_2 }],
				usage: { prompt_tokens: 20, total_tokens: 30 },
			})
			;(global.fetch as MockedFunction<typeof fetch>).mockResolvedValue(mockResponse as any)

			const result = await embedder.createEmbeddings(["test1", "test2"])
			expect(result.embeddings).toHaveLength(2)
			expectEmbeddingValues(result.embeddings[0], [0.25, 0.5])
			expectEmbeddingValues(result.embeddings[1], [0.75, 1.0])

			// A rejected fetch (network failure) surfaces as the generic failure message
			const networkError = new Error("Network failed")
			;(global.fetch as MockedFunction<typeof fetch>).mockRejectedValue(networkError)
			await expect(embedder.createEmbeddings(["test"])).rejects.toThrow(
				"Failed to create embeddings after 3 attempts",
			)
		})
	})
})
616773
})
617774
})

src/services/code-index/embedders/openai-compatible.ts

Lines changed: 72 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,8 @@ interface OpenAIEmbeddingResponse {
2929
export class OpenAICompatibleEmbedder implements IEmbedder {
3030
private embeddingsClient: OpenAI
3131
private readonly defaultModelId: string
32+
private readonly baseUrl: string
33+
private readonly apiKey: string
3234

3335
/**
3436
* Creates a new OpenAI Compatible embedder
@@ -44,6 +46,8 @@ export class OpenAICompatibleEmbedder implements IEmbedder {
4446
throw new Error("API key is required for OpenAI Compatible embedder")
4547
}
4648

49+
this.baseUrl = baseUrl
50+
this.apiKey = apiKey
4751
this.embeddingsClient = new OpenAI({
4852
baseURL: baseUrl,
4953
apiKey: apiKey,
@@ -109,6 +113,56 @@ export class OpenAICompatibleEmbedder implements IEmbedder {
109113
return { embeddings: allEmbeddings, usage }
110114
}
111115

116+
/**
 * Determines if the provided URL is a full endpoint URL or a base URL that needs the
 * endpoint appended by the SDK.
 *
 * Only the path is inspected: everything from the first `?` or `#` is ignored, so a query
 * value such as `?route=/embeddings` on a base URL is not mistaken for an endpoint.
 * A URL is treated as a full endpoint when its path ends with an `/embeddings` segment
 * (an optional trailing slash is allowed) or contains `/deployments/`.
 * A path that merely contains `/embeddings` elsewhere (e.g. `/embeddings-proxy/v1` or
 * `/embeddings/v1`) is treated as a base URL.
 * @param url The URL to check
 * @returns true if it's a full endpoint URL, false if it's a base URL
 */
private isFullEndpointUrl(url: string): boolean {
	// Keep only the part before the query string / fragment
	const [path = ""] = url.split(/[?#]/, 1)
	return /\/embeddings\/?$/.test(path) || path.includes("/deployments/")
}
126+
127+
/**
 * Makes a direct HTTP POST request to the embeddings endpoint.
 * Used when the user provides a full endpoint URL (e.g., Azure OpenAI with query parameters).
 * @param url The full endpoint URL
 * @param batchTexts Array of texts to embed
 * @param model Model identifier to use
 * @returns Promise resolving to the parsed JSON body, treated as an OpenAI-compatible response
 * @throws Error with a numeric `status` property (the HTTP status code) when the response is not ok
 */
private async makeDirectEmbeddingRequest(
	url: string,
	batchTexts: string[],
	model: string,
): Promise<OpenAIEmbeddingResponse> {
	const response = await fetch(url, {
		method: "POST",
		headers: {
			"Content-Type": "application/json",
			// Azure OpenAI uses the 'api-key' header, while OpenAI uses 'Authorization'.
			// Both are sent with every request so the same call works against either.
			"api-key": this.apiKey,
			Authorization: `Bearer ${this.apiKey}`,
		},
		body: JSON.stringify({
			input: batchTexts,
			model: model,
			// Embeddings are requested as base64; the caller decodes them
			encoding_format: "base64",
		}),
	})

	if (!response.ok) {
		let errorText: string
		try {
			errorText = await response.text()
		} catch {
			// The body could not be read; still report the failure with its HTTP status
			// so that callers checking `status` keep working
			errorText = response.statusText
		}
		throw Object.assign(new Error(`HTTP ${response.status}: ${errorText}`), { status: response.status })
	}

	return await response.json()
}
165+
112166
/**
113167
* Helper method to handle batch embedding with retries and exponential backoff
114168
* @param batchTexts Array of texts to embed in this batch
@@ -119,16 +173,26 @@ export class OpenAICompatibleEmbedder implements IEmbedder {
119173
batchTexts: string[],
120174
model: string,
121175
): Promise<{ embeddings: number[][]; usage: { promptTokens: number; totalTokens: number } }> {
176+
const isFullUrl = this.isFullEndpointUrl(this.baseUrl)
177+
122178
for (let attempts = 0; attempts < MAX_RETRIES; attempts++) {
123179
try {
124-
const response = (await this.embeddingsClient.embeddings.create({
125-
input: batchTexts,
126-
model: model,
127-
// OpenAI package (as of v4.78.1) has a parsing issue that truncates embedding dimensions to 256
128-
// when processing numeric arrays, which breaks compatibility with models using larger dimensions.
129-
// By requesting base64 encoding, we bypass the package's parser and handle decoding ourselves.
130-
encoding_format: "base64",
131-
})) as OpenAIEmbeddingResponse
180+
let response: OpenAIEmbeddingResponse
181+
182+
if (isFullUrl) {
183+
// Use direct HTTP request for full endpoint URLs
184+
response = await this.makeDirectEmbeddingRequest(this.baseUrl, batchTexts, model)
185+
} else {
186+
// Use OpenAI SDK for base URLs
187+
response = (await this.embeddingsClient.embeddings.create({
188+
input: batchTexts,
189+
model: model,
190+
// OpenAI package (as of v4.78.1) has a parsing issue that truncates embedding dimensions to 256
191+
// when processing numeric arrays, which breaks compatibility with models using larger dimensions.
192+
// By requesting base64 encoding, we bypass the package's parser and handle decoding ourselves.
193+
encoding_format: "base64",
194+
})) as OpenAIEmbeddingResponse
195+
}
132196

133197
// Convert base64 embeddings to float32 arrays
134198
const processedEmbeddings = response.data.map((item: EmbeddingItem) => {

0 commit comments

Comments
 (0)