119 changes: 115 additions & 4 deletions src/api/providers/__tests__/openai.test.ts
@@ -1,6 +1,7 @@
import { OpenAiHandler } from "../openai"
import { ApiHandlerOptions } from "../../../shared/api"
import { Anthropic } from "@anthropic-ai/sdk"
import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "../openai"

// Mock OpenAI client
const mockCreate = jest.fn()
@@ -202,10 +203,13 @@ describe("OpenAiHandler", () => {
it("should complete prompt successfully", async () => {
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.openAiModelId,
messages: [{ role: "user", content: "Test prompt" }],
})
expect(mockCreate).toHaveBeenCalledWith(
{
model: mockOptions.openAiModelId,
messages: [{ role: "user", content: "Test prompt" }],
},
{},
)
})

it("should handle API errors", async () => {
@@ -241,4 +245,111 @@ describe("OpenAiHandler", () => {
expect(model.info).toBeDefined()
})
})

describe("Azure AI Inference Service", () => {
const azureOptions = {
...mockOptions,
openAiBaseUrl: "https://test.services.ai.azure.com",
openAiModelId: "deepseek-v3",
azureApiVersion: "2024-05-01-preview",
}

it("should initialize with Azure AI Inference Service configuration", () => {
const azureHandler = new OpenAiHandler(azureOptions)
expect(azureHandler).toBeInstanceOf(OpenAiHandler)
expect(azureHandler.getModel().id).toBe(azureOptions.openAiModelId)
})

it("should handle streaming responses with Azure AI Inference Service", async () => {
const azureHandler = new OpenAiHandler(azureOptions)
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello!",
},
]

const stream = azureHandler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}

expect(chunks.length).toBeGreaterThan(0)
const textChunks = chunks.filter((chunk) => chunk.type === "text")
expect(textChunks).toHaveLength(1)
expect(textChunks[0].text).toBe("Test response")

// Verify the API call was made with correct Azure AI Inference Service path
expect(mockCreate).toHaveBeenCalledWith(
{
model: azureOptions.openAiModelId,
messages: [
{ role: "system", content: systemPrompt },
{ role: "user", content: "Hello!" },
],
stream: true,
stream_options: { include_usage: true },
temperature: 0,
},
{ path: "/models/chat/completions" },
)
})

it("should handle non-streaming responses with Azure AI Inference Service", async () => {
const azureHandler = new OpenAiHandler({
...azureOptions,
openAiStreamingEnabled: false,
})
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello!",
},
]

const stream = azureHandler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}

expect(chunks.length).toBeGreaterThan(0)
const textChunk = chunks.find((chunk) => chunk.type === "text")
const usageChunk = chunks.find((chunk) => chunk.type === "usage")

expect(textChunk).toBeDefined()
expect(textChunk?.text).toBe("Test response")
expect(usageChunk).toBeDefined()
expect(usageChunk?.inputTokens).toBe(10)
expect(usageChunk?.outputTokens).toBe(5)

// Verify the API call was made with correct Azure AI Inference Service path
expect(mockCreate).toHaveBeenCalledWith(
{
model: azureOptions.openAiModelId,
messages: [
{ role: "user", content: systemPrompt },
{ role: "user", content: "Hello!" },
],
},
{ path: "/models/chat/completions" },
)
})

it("should handle completePrompt with Azure AI Inference Service", async () => {
const azureHandler = new OpenAiHandler(azureOptions)
const result = await azureHandler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith(
{
model: azureOptions.openAiModelId,
messages: [{ role: "user", content: "Test prompt" }],
},
{ path: "/models/chat/completions" },
)
})
})
})
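
The new tests point the handler at a base URL whose host ends in ".services.ai.azure.com" and assert that every request carries the { path: "/models/chat/completions" } override. A minimal usage sketch of the configuration these tests exercise follows; the endpoint, key, and version strings are placeholders rather than values from this PR, and openAiApiKey is assumed to be a field of ApiHandlerOptions, as the tests' mockOptions suggest.

import { OpenAiHandler } from "../openai"
import { ApiHandlerOptions } from "../../../shared/api"

// Placeholder configuration for illustration only; a real deployment supplies its own
// endpoint, key, and model. openAiApiKey is assumed to exist on ApiHandlerOptions.
const options: ApiHandlerOptions = {
	openAiBaseUrl: "https://my-resource.services.ai.azure.com",
	openAiApiKey: "<api-key>",
	openAiModelId: "deepseek-v3",
	azureApiVersion: "2024-05-01-preview",
}

const handler = new OpenAiHandler(options)
// Requests made through this handler should target "/models/chat/completions",
// with the configured api-version sent as a query parameter.
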
102 changes: 83 additions & 19 deletions src/api/providers/openai.ts
@@ -16,7 +16,7 @@ import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
import { BaseProvider } from "./base-provider"
import { XmlMatcher } from "../../utils/xml-matcher"

const DEEP_SEEK_DEFAULT_TEMPERATURE = 0.6
export const DEEP_SEEK_DEFAULT_TEMPERATURE = 0.6

export const defaultHeaders = {
"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
@@ -45,7 +45,18 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
urlHost = ""
}

if (urlHost === "azure.com" || urlHost.endsWith(".azure.com") || options.openAiUseAzure) {
const isAzureAiInference = urlHost.endsWith(".services.ai.azure.com")
const isAzureOpenAi = urlHost === "azure.com" || urlHost.endsWith(".azure.com") || options.openAiUseAzure

if (isAzureAiInference) {
// Azure AI Inference Service (e.g., for DeepSeek) uses a different path structure
this.client = new OpenAI({
baseURL,
apiKey,
defaultHeaders,
defaultQuery: { "api-version": this.options.azureApiVersion || "2024-05-01-preview" },
})
} else if (isAzureOpenAi) {
// Azure API shape slightly differs from the core API shape:
// https://github.com/openai/openai-node?tab=readme-ov-file#microsoft-azure-openai
this.client = new AzureOpenAI({
@@ -64,6 +75,15 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
const modelUrl = this.options.openAiBaseUrl ?? ""
const modelId = this.options.openAiModelId ?? ""
const enabledR1Format = this.options.openAiR1FormatEnabled ?? false
// Add Azure AI Inference check within this method
let urlHost: string
try {
urlHost = new URL(modelUrl).host
} catch (error) {
urlHost = ""
}
const isAzureAiInference = urlHost.endsWith(".services.ai.azure.com")
const azureAiInferencePath = "/models/chat/completions" // Path for Azure AI Inference
const deepseekReasoner = modelId.includes("deepseek-reasoner") || enabledR1Format
const ark = modelUrl.includes(".volces.com")
if (modelId.startsWith("o3-mini")) {
@@ -132,7 +152,10 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
requestOptions.max_tokens = modelInfo.maxTokens
}

const stream = await this.client.chat.completions.create(requestOptions)
const stream = await this.client.chat.completions.create(
requestOptions,
isAzureAiInference ? { path: azureAiInferencePath } : {},
)

const matcher = new XmlMatcher(
"think",
@@ -185,7 +208,10 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
: [systemMessage, ...convertToOpenAiMessages(messages)],
}

const response = await this.client.chat.completions.create(requestOptions)
const response = await this.client.chat.completions.create(
requestOptions,
isAzureAiInference ? { path: azureAiInferencePath } : {},
)

yield {
type: "text",
Expand All @@ -212,12 +238,24 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl

async completePrompt(prompt: string): Promise<string> {
try {
// Add Azure AI Inference check within this method
let urlHost: string
try {
urlHost = new URL(this.options.openAiBaseUrl ?? "").host
} catch (error) {
urlHost = ""
}
const isAzureAiInference = urlHost.endsWith(".services.ai.azure.com")
const azureAiInferencePath = "/models/chat/completions" // Path for Azure AI Inference
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: this.getModel().id,
messages: [{ role: "user", content: prompt }],
}

const response = await this.client.chat.completions.create(requestOptions)
const response = await this.client.chat.completions.create(
requestOptions,
isAzureAiInference ? { path: azureAiInferencePath } : {},
)
return response.choices[0]?.message.content || ""
} catch (error) {
if (error instanceof Error) {
@@ -233,19 +271,32 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
messages: Anthropic.Messages.MessageParam[],
): ApiStream {
if (this.options.openAiStreamingEnabled ?? true) {
const stream = await this.client.chat.completions.create({
model: modelId,
messages: [
{
role: "developer",
content: `Formatting re-enabled\n${systemPrompt}`,
},
...convertToOpenAiMessages(messages),
],
stream: true,
stream_options: { include_usage: true },
reasoning_effort: this.getModel().info.reasoningEffort,
})
// Add Azure AI Inference check within this method scope
let methodUrlHost: string
try {
methodUrlHost = new URL(this.options.openAiBaseUrl ?? "").host
} catch (error) {
methodUrlHost = ""
}
const methodIsAzureAiInference = methodUrlHost.endsWith(".services.ai.azure.com")
const methodAzureAiInferencePath = "/models/chat/completions"

const stream = await this.client.chat.completions.create(
{
model: modelId,
messages: [
{
role: "developer",
content: `Formatting re-enabled\n${systemPrompt}`,
},
...convertToOpenAiMessages(messages),
],
stream: true,
stream_options: { include_usage: true },
reasoning_effort: this.getModel().info.reasoningEffort,
},
methodIsAzureAiInference ? { path: methodAzureAiInferencePath } : {},
)

yield* this.handleStreamResponse(stream)
} else {
@@ -260,7 +311,20 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
],
}

const response = await this.client.chat.completions.create(requestOptions)
// Add Azure AI Inference check within this method scope
let methodUrlHost: string
try {
methodUrlHost = new URL(this.options.openAiBaseUrl ?? "").host
} catch (error) {
methodUrlHost = ""
}
const methodIsAzureAiInference = methodUrlHost.endsWith(".services.ai.azure.com")
const methodAzureAiInferencePath = "/models/chat/completions"

const response = await this.client.chat.completions.create(
requestOptions,
methodIsAzureAiInference ? { path: methodAzureAiInferencePath } : {},
)

yield {
type: "text",
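
The endpoint detection added above recurs in the constructor, createMessage, completePrompt, and the o3-family handler. For reference, the repeated logic reduces to the following standalone check; this is a sketch of what the diff does inline in each method, not a helper exported by the PR.

// Assumption: a host suffix of ".services.ai.azure.com" is the intended discriminator,
// matching the checks added in each method of this diff.
function isAzureAiInferenceHost(baseUrl: string | undefined): boolean {
	let host: string
	try {
		host = new URL(baseUrl ?? "").host
	} catch {
		host = ""
	}
	return host.endsWith(".services.ai.azure.com")
}

// Usage mirroring the diff: apply the per-request path override only for Azure AI Inference.
const requestPathOptions = isAzureAiInferenceHost("https://example.services.ai.azure.com")
	? { path: "/models/chat/completions" }
	: {}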