119 changes: 115 additions & 4 deletions src/api/providers/__tests__/openai.test.ts
@@ -1,6 +1,7 @@
import { OpenAiHandler } from "../openai"
import { ApiHandlerOptions } from "../../../shared/api"
import { Anthropic } from "@anthropic-ai/sdk"
import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "../constants"

// Mock OpenAI client
const mockCreate = jest.fn()
@@ -202,10 +203,13 @@ describe("OpenAiHandler", () => {
it("should complete prompt successfully", async () => {
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.openAiModelId,
messages: [{ role: "user", content: "Test prompt" }],
})
expect(mockCreate).toHaveBeenCalledWith(
{
model: mockOptions.openAiModelId,
messages: [{ role: "user", content: "Test prompt" }],
},
{},
)
})

it("should handle API errors", async () => {
@@ -241,4 +245,111 @@ describe("OpenAiHandler", () => {
expect(model.info).toBeDefined()
})
})

describe("Azure AI Inference Service", () => {
const azureOptions = {
...mockOptions,
openAiBaseUrl: "https://test.services.ai.azure.com",
openAiModelId: "deepseek-v3",
azureApiVersion: "2024-05-01-preview",
}

it("should initialize with Azure AI Inference Service configuration", () => {
const azureHandler = new OpenAiHandler(azureOptions)
expect(azureHandler).toBeInstanceOf(OpenAiHandler)
expect(azureHandler.getModel().id).toBe(azureOptions.openAiModelId)
})

it("should handle streaming responses with Azure AI Inference Service", async () => {
const azureHandler = new OpenAiHandler(azureOptions)
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello!",
},
]

const stream = azureHandler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}

expect(chunks.length).toBeGreaterThan(0)
const textChunks = chunks.filter((chunk) => chunk.type === "text")
expect(textChunks).toHaveLength(1)
expect(textChunks[0].text).toBe("Test response")

// Verify the API call was made with correct Azure AI Inference Service path
expect(mockCreate).toHaveBeenCalledWith(
{
model: azureOptions.openAiModelId,
messages: [
{ role: "system", content: systemPrompt },
{ role: "user", content: "Hello!" },
],
stream: true,
stream_options: { include_usage: true },
temperature: 0,
},
{ path: "/models/chat/completions" },
)
})

it("should handle non-streaming responses with Azure AI Inference Service", async () => {
const azureHandler = new OpenAiHandler({
...azureOptions,
openAiStreamingEnabled: false,
})
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello!",
},
]

const stream = azureHandler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}

expect(chunks.length).toBeGreaterThan(0)
const textChunk = chunks.find((chunk) => chunk.type === "text")
const usageChunk = chunks.find((chunk) => chunk.type === "usage")

expect(textChunk).toBeDefined()
expect(textChunk?.text).toBe("Test response")
expect(usageChunk).toBeDefined()
expect(usageChunk?.inputTokens).toBe(10)
expect(usageChunk?.outputTokens).toBe(5)

// Verify the API call was made with correct Azure AI Inference Service path
expect(mockCreate).toHaveBeenCalledWith(
{
model: azureOptions.openAiModelId,
messages: [
{ role: "user", content: systemPrompt },
{ role: "user", content: "Hello!" },
],
},
{ path: "/models/chat/completions" },
)
})

it("should handle completePrompt with Azure AI Inference Service", async () => {
const azureHandler = new OpenAiHandler(azureOptions)
const result = await azureHandler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith(
{
model: azureOptions.openAiModelId,
messages: [{ role: "user", content: "Test prompt" }],
},
{ path: "/models/chat/completions" },
)
})
})
})
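
Note: the assertions above rely on the OpenAI Node SDK accepting a per-request options object as the second argument to `chat.completions.create`, where `path` overrides the default `/chat/completions` route. A minimal sketch of the resulting call shape, using placeholder resource name, API key variable, and model id (not taken from this PR):

```ts
import OpenAI from "openai"

// Placeholder values for illustration only.
const client = new OpenAI({
	baseURL: "https://my-resource.services.ai.azure.com",
	apiKey: process.env.AZURE_AI_API_KEY ?? "not-provided",
	defaultQuery: { "api-version": "2024-05-01-preview" },
})

async function demo() {
	// The second argument is the SDK's per-request options; `path` replaces the
	// default "/chat/completions" route, so the request is sent to
	// https://my-resource.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview
	const response = await client.chat.completions.create(
		{ model: "deepseek-v3", messages: [{ role: "user", content: "Hello!" }] },
		{ path: "/models/chat/completions" },
	)
	return response.choices[0]?.message.content ?? ""
}
```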
40 changes: 36 additions & 4 deletions src/api/providers/__tests__/requesty.test.ts
@@ -38,16 +38,43 @@ describe("RequestyHandler", () => {
// Clear mocks
jest.clearAllMocks()

// Setup mock create function
mockCreate = jest.fn()
// Setup mock create function that preserves params
let lastParams: any
mockCreate = jest.fn().mockImplementation((params) => {
lastParams = params
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [{ delta: { content: "Hello" } }],
}
yield {
choices: [{ delta: { content: " world" } }],
usage: {
prompt_tokens: 30,
completion_tokens: 10,
prompt_tokens_details: {
cached_tokens: 15,
caching_tokens: 5,
},
},
}
},
}
})

// Mock OpenAI constructor
;(OpenAI as jest.MockedClass<typeof OpenAI>).mockImplementation(
() =>
({
chat: {
completions: {
create: mockCreate,
create: (params: any) => {
// Store params for verification
const result = mockCreate(params)
// Make params available for test assertions
;(result as any).params = params
return result
},
},
},
}) as unknown as OpenAI,
@@ -122,7 +149,12 @@ describe("RequestyHandler", () => {
},
])

expect(mockCreate).toHaveBeenCalledWith({
// Get the actual params that were passed
const calls = mockCreate.mock.calls
expect(calls.length).toBe(1)
const actualParams = calls[0][0]

expect(actualParams).toEqual({
model: defaultOptions.requestyModelId,
temperature: 0,
messages: [
103 changes: 83 additions & 20 deletions src/api/providers/openai.ts
@@ -15,8 +15,7 @@ import { convertToSimpleMessages } from "../transform/simple-format"
import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
import { BaseProvider } from "./base-provider"
import { XmlMatcher } from "../../utils/xml-matcher"

const DEEP_SEEK_DEFAULT_TEMPERATURE = 0.6
import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"

export const defaultHeaders = {
"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
@@ -25,13 +24,17 @@ export const defaultHeaders = {

export interface OpenAiHandlerOptions extends ApiHandlerOptions {}

const AZURE_AI_INFERENCE_PATH = "/models/chat/completions"

export class OpenAiHandler extends BaseProvider implements SingleCompletionHandler {
protected options: OpenAiHandlerOptions
private client: OpenAI
private isAzure: boolean

constructor(options: OpenAiHandlerOptions) {
super()
this.options = options
this.isAzure = options.openAiUseAzure ?? false

const baseURL = this.options.openAiBaseUrl ?? "https://api.openai.com/v1"
const apiKey = this.options.openAiApiKey ?? "not-provided"
@@ -45,7 +48,18 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandler
urlHost = ""
}

if (urlHost === "azure.com" || urlHost.endsWith(".azure.com") || options.openAiUseAzure) {
const isAzureAiInference = urlHost.endsWith(".services.ai.azure.com")
const isAzureOpenAi = urlHost === "azure.com" || urlHost.endsWith(".azure.com") || options.openAiUseAzure

if (isAzureAiInference) {
// Azure AI Inference Service (e.g., for DeepSeek) uses a different path structure
this.client = new OpenAI({
baseURL,
apiKey,
defaultHeaders,
defaultQuery: { "api-version": this.options.azureApiVersion || "2024-05-01-preview" },
})
} else if (isAzureOpenAi) {
// Azure API shape slightly differs from the core API shape:
// https://github.com/openai/openai-node?tab=readme-ov-file#microsoft-azure-openai
this.client = new AzureOpenAI({
@@ -64,6 +78,14 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandler
const modelUrl = this.options.openAiBaseUrl ?? ""
const modelId = this.options.openAiModelId ?? ""
const enabledR1Format = this.options.openAiR1FormatEnabled ?? false
// Add Azure AI Inference check within this method
let urlHost: string
try {
urlHost = new URL(modelUrl).host
} catch (error) {
urlHost = ""
}
const isAzureAiInference = urlHost.endsWith(".services.ai.azure.com")
const deepseekReasoner = modelId.includes("deepseek-reasoner") || enabledR1Format
const ark = modelUrl.includes(".volces.com")
if (modelId.startsWith("o3-mini")) {
@@ -132,7 +154,10 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandler
requestOptions.max_tokens = modelInfo.maxTokens
}

const stream = await this.client.chat.completions.create(requestOptions)
const stream = await this.client.chat.completions.create(
requestOptions,
isAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
)

const matcher = new XmlMatcher(
"think",
@@ -185,7 +210,10 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandler
: [systemMessage, ...convertToOpenAiMessages(messages)],
}

const response = await this.client.chat.completions.create(requestOptions)
const response = await this.client.chat.completions.create(
requestOptions,
isAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
)

yield {
type: "text",
@@ -212,12 +240,23 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandler

async completePrompt(prompt: string): Promise<string> {
try {
// Add Azure AI Inference check within this method
let urlHost: string
try {
urlHost = new URL(this.options.openAiBaseUrl ?? "").host
} catch (error) {
urlHost = ""
}
const isAzureAiInference = urlHost.endsWith(".services.ai.azure.com")
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: this.getModel().id,
messages: [{ role: "user", content: prompt }],
}

const response = await this.client.chat.completions.create(requestOptions)
const response = await this.client.chat.completions.create(
requestOptions,
isAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
)
return response.choices[0]?.message.content || ""
} catch (error) {
if (error instanceof Error) {
@@ -233,19 +272,31 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandler
messages: Anthropic.Messages.MessageParam[],
): ApiStream {
if (this.options.openAiStreamingEnabled ?? true) {
const stream = await this.client.chat.completions.create({
model: modelId,
messages: [
{
role: "developer",
content: `Formatting re-enabled\n${systemPrompt}`,
},
...convertToOpenAiMessages(messages),
],
stream: true,
stream_options: { include_usage: true },
reasoning_effort: this.getModel().info.reasoningEffort,
})
// Add Azure AI Inference check within this method scope
let methodUrlHost: string
try {
methodUrlHost = new URL(this.options.openAiBaseUrl ?? "").host
} catch (error) {
methodUrlHost = ""
}
const methodIsAzureAiInference = methodUrlHost.endsWith(".services.ai.azure.com")

const stream = await this.client.chat.completions.create(
{
model: modelId,
messages: [
{
role: "developer",
content: `Formatting re-enabled\n${systemPrompt}`,
},
...convertToOpenAiMessages(messages),
],
stream: true,
stream_options: { include_usage: true },
reasoning_effort: this.getModel().info.reasoningEffort,
},
methodIsAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
)

yield* this.handleStreamResponse(stream)
} else {
@@ -260,7 +311,19 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandler
],
}

const response = await this.client.chat.completions.create(requestOptions)
// Add Azure AI Inference check within this method scope
let methodUrlHost: string
try {
methodUrlHost = new URL(this.options.openAiBaseUrl ?? "").host
} catch (error) {
methodUrlHost = ""
}
const methodIsAzureAiInference = methodUrlHost.endsWith(".services.ai.azure.com")

const response = await this.client.chat.completions.create(
requestOptions,
methodIsAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
)

yield {
type: "text",
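
Note: the `.services.ai.azure.com` host check above is repeated in the constructor, `createMessage`, `completePrompt`, and the o3-mini handler. As a side observation, a sketch of how that check could be centralized in a small helper — illustrative only, with a hypothetical function name, not something this PR adds:

```ts
// Hypothetical helper, not part of this PR.
export function isAzureAiInferenceUrl(baseUrl?: string): boolean {
	try {
		return new URL(baseUrl ?? "").host.endsWith(".services.ai.azure.com")
	} catch {
		return false
	}
}

// Possible usage inside the handler methods, assuming the same options shape:
// const requestOptions = isAzureAiInferenceUrl(this.options.openAiBaseUrl)
// 	? { path: AZURE_AI_INFERENCE_PATH }
// 	: {}
```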