
Commit e9bcee5

feat: Add support for Azure AI Inference Service with DeepSeek-V3 model (#2241)
* feat: Add support for Azure AI Inference Service with DeepSeek-V3 model
* refactor: extract Azure AI inference path to constant to avoid duplication
* fix(tests): update RequestyHandler tests to properly handle Azure inference and streaming
* fix(api): remove duplicate constant and update requesty tests
* refactor: remove unused isAzure property from OpenAiHandler
* refactor(openai): remove unused isAzure and extract Azure check
1 parent 49a61ca commit e9bcee5
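
For orientation before the diff: a minimal, hypothetical sketch of how the new code path might be exercised. The option names (openAiBaseUrl, openAiApiKey, openAiModelId, azureApiVersion) are the ones used in the diff below; the import path, values, and standalone usage are illustrative, not an official configuration example.

import { OpenAiHandler } from "./src/api/providers/openai"

// Hosts ending in ".services.ai.azure.com" are detected as Azure AI Inference,
// and requests are then routed to the "/models/chat/completions" path.
const handler = new OpenAiHandler({
	openAiBaseUrl: "https://my-resource.services.ai.azure.com", // illustrative resource name
	openAiApiKey: "<api-key>",
	openAiModelId: "deepseek-v3",
	azureApiVersion: "2024-05-01-preview",
})

const reply = await handler.completePrompt("Hello from DeepSeek-V3 on Azure!")
console.log(reply)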

File tree

src/api/providers/__tests__/openai.test.ts
src/api/providers/__tests__/requesty.test.ts
src/api/providers/openai.ts

3 files changed: +217 −37 lines changed


src/api/providers/__tests__/openai.test.ts

Lines changed: 115 additions & 4 deletions
@@ -1,6 +1,7 @@
 import { OpenAiHandler } from "../openai"
 import { ApiHandlerOptions } from "../../../shared/api"
 import { Anthropic } from "@anthropic-ai/sdk"
+import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "../constants"

 // Mock OpenAI client
 const mockCreate = jest.fn()
@@ -202,10 +203,13 @@ describe("OpenAiHandler", () => {
 		it("should complete prompt successfully", async () => {
 			const result = await handler.completePrompt("Test prompt")
 			expect(result).toBe("Test response")
-			expect(mockCreate).toHaveBeenCalledWith({
-				model: mockOptions.openAiModelId,
-				messages: [{ role: "user", content: "Test prompt" }],
-			})
+			expect(mockCreate).toHaveBeenCalledWith(
+				{
+					model: mockOptions.openAiModelId,
+					messages: [{ role: "user", content: "Test prompt" }],
+				},
+				{},
+			)
 		})

 		it("should handle API errors", async () => {
@@ -241,4 +245,111 @@ describe("OpenAiHandler", () => {
 			expect(model.info).toBeDefined()
 		})
 	})
+
+	describe("Azure AI Inference Service", () => {
+		const azureOptions = {
+			...mockOptions,
+			openAiBaseUrl: "https://test.services.ai.azure.com",
+			openAiModelId: "deepseek-v3",
+			azureApiVersion: "2024-05-01-preview",
+		}
+
+		it("should initialize with Azure AI Inference Service configuration", () => {
+			const azureHandler = new OpenAiHandler(azureOptions)
+			expect(azureHandler).toBeInstanceOf(OpenAiHandler)
+			expect(azureHandler.getModel().id).toBe(azureOptions.openAiModelId)
+		})
+
+		it("should handle streaming responses with Azure AI Inference Service", async () => {
+			const azureHandler = new OpenAiHandler(azureOptions)
+			const systemPrompt = "You are a helpful assistant."
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: "Hello!",
+				},
+			]
+
+			const stream = azureHandler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks.length).toBeGreaterThan(0)
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(1)
+			expect(textChunks[0].text).toBe("Test response")
+
+			// Verify the API call was made with correct Azure AI Inference Service path
+			expect(mockCreate).toHaveBeenCalledWith(
+				{
+					model: azureOptions.openAiModelId,
+					messages: [
+						{ role: "system", content: systemPrompt },
+						{ role: "user", content: "Hello!" },
+					],
+					stream: true,
+					stream_options: { include_usage: true },
+					temperature: 0,
+				},
+				{ path: "/models/chat/completions" },
+			)
+		})
+
+		it("should handle non-streaming responses with Azure AI Inference Service", async () => {
+			const azureHandler = new OpenAiHandler({
+				...azureOptions,
+				openAiStreamingEnabled: false,
+			})
+			const systemPrompt = "You are a helpful assistant."
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: "Hello!",
+				},
+			]
+
+			const stream = azureHandler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks.length).toBeGreaterThan(0)
+			const textChunk = chunks.find((chunk) => chunk.type === "text")
+			const usageChunk = chunks.find((chunk) => chunk.type === "usage")
+
+			expect(textChunk).toBeDefined()
+			expect(textChunk?.text).toBe("Test response")
+			expect(usageChunk).toBeDefined()
+			expect(usageChunk?.inputTokens).toBe(10)
+			expect(usageChunk?.outputTokens).toBe(5)
+
+			// Verify the API call was made with correct Azure AI Inference Service path
+			expect(mockCreate).toHaveBeenCalledWith(
+				{
+					model: azureOptions.openAiModelId,
+					messages: [
+						{ role: "user", content: systemPrompt },
+						{ role: "user", content: "Hello!" },
+					],
+				},
+				{ path: "/models/chat/completions" },
+			)
+		})
+
+		it("should handle completePrompt with Azure AI Inference Service", async () => {
+			const azureHandler = new OpenAiHandler(azureOptions)
+			const result = await azureHandler.completePrompt("Test prompt")
+			expect(result).toBe("Test response")
+			expect(mockCreate).toHaveBeenCalledWith(
+				{
+					model: azureOptions.openAiModelId,
+					messages: [{ role: "user", content: "Test prompt" }],
+				},
+				{ path: "/models/chat/completions" },
+			)
+		})
+	})
 })
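
The `{}` and `{ path: "/models/chat/completions" }` second arguments asserted above are the OpenAI Node SDK's per-request options. As a rough, self-contained sketch of the call shape being tested (endpoint, key, and values are illustrative, not from the repository):

import OpenAI from "openai"

// Client configured the way the Azure AI Inference branch in openai.ts (below) does it:
// a *.services.ai.azure.com base URL plus an api-version default query parameter.
const client = new OpenAI({
	baseURL: "https://test.services.ai.azure.com",
	apiKey: "<api-key>",
	defaultQuery: { "api-version": "2024-05-01-preview" },
})

// Request body first, per-request options second; the `path` override sends the call
// to "/models/chat/completions" instead of the SDK's default "/chat/completions".
const response = await client.chat.completions.create(
	{
		model: "deepseek-v3",
		messages: [{ role: "user", content: "Test prompt" }],
	},
	{ path: "/models/chat/completions" },
)
console.log(response.choices[0]?.message.content)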

src/api/providers/__tests__/requesty.test.ts

Lines changed: 36 additions & 4 deletions
@@ -38,16 +38,43 @@ describe("RequestyHandler", () => {
 		// Clear mocks
 		jest.clearAllMocks()

-		// Setup mock create function
-		mockCreate = jest.fn()
+		// Setup mock create function that preserves params
+		let lastParams: any
+		mockCreate = jest.fn().mockImplementation((params) => {
+			lastParams = params
+			return {
+				[Symbol.asyncIterator]: async function* () {
+					yield {
+						choices: [{ delta: { content: "Hello" } }],
+					}
+					yield {
+						choices: [{ delta: { content: " world" } }],
+						usage: {
+							prompt_tokens: 30,
+							completion_tokens: 10,
+							prompt_tokens_details: {
+								cached_tokens: 15,
+								caching_tokens: 5,
+							},
+						},
+					}
+				},
+			}
+		})

 		// Mock OpenAI constructor
 		;(OpenAI as jest.MockedClass<typeof OpenAI>).mockImplementation(
 			() =>
 				({
 					chat: {
 						completions: {
-							create: mockCreate,
+							create: (params: any) => {
+								// Store params for verification
+								const result = mockCreate(params)
+								// Make params available for test assertions
+								;(result as any).params = params
+								return result
+							},
 						},
 					},
 				}) as unknown as OpenAI,
@@ -122,7 +149,12 @@ describe("RequestyHandler", () => {
 				},
 			])

-			expect(mockCreate).toHaveBeenCalledWith({
+			// Get the actual params that were passed
+			const calls = mockCreate.mock.calls
+			expect(calls.length).toBe(1)
+			const actualParams = calls[0][0]
+
+			expect(actualParams).toEqual({
 				model: defaultOptions.requestyModelId,
 				temperature: 0,
 				messages: [

src/api/providers/openai.ts

Lines changed: 66 additions & 29 deletions
@@ -15,8 +15,7 @@ import { convertToSimpleMessages } from "../transform/simple-format"
 import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
 import { BaseProvider } from "./base-provider"
 import { XmlMatcher } from "../../utils/xml-matcher"
-
-const DEEP_SEEK_DEFAULT_TEMPERATURE = 0.6
+import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"

 export const defaultHeaders = {
 	"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
@@ -25,6 +24,8 @@ export const defaultHeaders = {

 export interface OpenAiHandlerOptions extends ApiHandlerOptions {}

+const AZURE_AI_INFERENCE_PATH = "/models/chat/completions"
+
 export class OpenAiHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: OpenAiHandlerOptions
 	private client: OpenAI
@@ -35,17 +36,19 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl

 		const baseURL = this.options.openAiBaseUrl ?? "https://api.openai.com/v1"
 		const apiKey = this.options.openAiApiKey ?? "not-provided"
-		let urlHost: string
-
-		try {
-			urlHost = new URL(this.options.openAiBaseUrl ?? "").host
-		} catch (error) {
-			// Likely an invalid `openAiBaseUrl`; we're still working on
-			// proper settings validation.
-			urlHost = ""
-		}
+		const isAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
+		const urlHost = this._getUrlHost(this.options.openAiBaseUrl)
+		const isAzureOpenAi = urlHost === "azure.com" || urlHost.endsWith(".azure.com") || options.openAiUseAzure

-		if (urlHost === "azure.com" || urlHost.endsWith(".azure.com") || options.openAiUseAzure) {
+		if (isAzureAiInference) {
+			// Azure AI Inference Service (e.g., for DeepSeek) uses a different path structure
+			this.client = new OpenAI({
+				baseURL,
+				apiKey,
+				defaultHeaders,
+				defaultQuery: { "api-version": this.options.azureApiVersion || "2024-05-01-preview" },
+			})
+		} else if (isAzureOpenAi) {
 			// Azure API shape slightly differs from the core API shape:
 			// https://github.com/openai/openai-node?tab=readme-ov-file#microsoft-azure-openai
 			this.client = new AzureOpenAI({
@@ -64,6 +67,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		const modelUrl = this.options.openAiBaseUrl ?? ""
 		const modelId = this.options.openAiModelId ?? ""
 		const enabledR1Format = this.options.openAiR1FormatEnabled ?? false
+		const isAzureAiInference = this._isAzureAiInference(modelUrl)
+		const urlHost = this._getUrlHost(modelUrl)
 		const deepseekReasoner = modelId.includes("deepseek-reasoner") || enabledR1Format
 		const ark = modelUrl.includes(".volces.com")
 		if (modelId.startsWith("o3-mini")) {
@@ -132,7 +137,10 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				requestOptions.max_tokens = modelInfo.maxTokens
 			}

-			const stream = await this.client.chat.completions.create(requestOptions)
+			const stream = await this.client.chat.completions.create(
+				requestOptions,
+				isAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
+			)

 			const matcher = new XmlMatcher(
 				"think",
@@ -185,7 +193,10 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 					: [systemMessage, ...convertToOpenAiMessages(messages)],
 			}

-			const response = await this.client.chat.completions.create(requestOptions)
+			const response = await this.client.chat.completions.create(
+				requestOptions,
+				this._isAzureAiInference(modelUrl) ? { path: AZURE_AI_INFERENCE_PATH } : {},
+			)

 			yield {
 				type: "text",
@@ -212,12 +223,16 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl

 	async completePrompt(prompt: string): Promise<string> {
 		try {
+			const isAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
 			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
 				model: this.getModel().id,
 				messages: [{ role: "user", content: prompt }],
 			}

-			const response = await this.client.chat.completions.create(requestOptions)
+			const response = await this.client.chat.completions.create(
+				requestOptions,
+				isAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
+			)
 			return response.choices[0]?.message.content || ""
 		} catch (error) {
 			if (error instanceof Error) {
@@ -233,19 +248,24 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		messages: Anthropic.Messages.MessageParam[],
 	): ApiStream {
 		if (this.options.openAiStreamingEnabled ?? true) {
-			const stream = await this.client.chat.completions.create({
-				model: modelId,
-				messages: [
-					{
-						role: "developer",
-						content: `Formatting re-enabled\n${systemPrompt}`,
-					},
-					...convertToOpenAiMessages(messages),
-				],
-				stream: true,
-				stream_options: { include_usage: true },
-				reasoning_effort: this.getModel().info.reasoningEffort,
-			})
+			const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
+
+			const stream = await this.client.chat.completions.create(
+				{
+					model: modelId,
+					messages: [
+						{
+							role: "developer",
+							content: `Formatting re-enabled\n${systemPrompt}`,
+						},
+						...convertToOpenAiMessages(messages),
+					],
+					stream: true,
+					stream_options: { include_usage: true },
+					reasoning_effort: this.getModel().info.reasoningEffort,
+				},
+				methodIsAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
+			)

 			yield* this.handleStreamResponse(stream)
 		} else {
@@ -260,7 +280,12 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				],
 			}

-			const response = await this.client.chat.completions.create(requestOptions)
+			const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
+
+			const response = await this.client.chat.completions.create(
+				requestOptions,
+				methodIsAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
+			)

 			yield {
 				type: "text",
@@ -289,6 +314,18 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 			}
 		}
 	}
+	private _getUrlHost(baseUrl?: string): string {
+		try {
+			return new URL(baseUrl ?? "").host
+		} catch (error) {
+			return ""
+		}
+	}
+
+	private _isAzureAiInference(baseUrl?: string): boolean {
+		const urlHost = this._getUrlHost(baseUrl)
+		return urlHost.endsWith(".services.ai.azure.com")
+	}
 }

 export async function getOpenAiModels(baseUrl?: string, apiKey?: string) {
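
For reference, the detection added above behaves roughly like this standalone sketch; the function name isAzureAiInferenceUrl is invented here for illustration, while the real logic lives in the private _getUrlHost / _isAzureAiInference helpers shown in the diff:

// Only hosts ending in ".services.ai.azure.com" take the "/models/chat/completions"
// path; invalid or missing URLs fall back to the regular OpenAI-compatible behavior.
function isAzureAiInferenceUrl(baseUrl?: string): boolean {
	try {
		return new URL(baseUrl ?? "").host.endsWith(".services.ai.azure.com")
	} catch {
		return false
	}
}

console.log(isAzureAiInferenceUrl("https://test.services.ai.azure.com")) // true  -> Azure AI Inference path
console.log(isAzureAiInferenceUrl("https://my-org.openai.azure.com")) // false -> classic Azure OpenAI client
console.log(isAzureAiInferenceUrl("https://api.openai.com/v1")) // false -> default OpenAI behavior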
