197 changes: 161 additions & 36 deletions src/api/providers/openai.ts
@@ -35,40 +35,64 @@
super()
this.options = options

const baseURL = this.options.openAiBaseUrl ?? "https://api.openai.com/v1"
const apiKey = this.options.openAiApiKey ?? "not-provided"
const isAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
const urlHost = this._getUrlHost(this.options.openAiBaseUrl)
const isAzureOpenAi = urlHost === "azure.com" || urlHost.endsWith(".azure.com") || options.openAiUseAzure
// Use azureApiVersion as primary indicator (like Cline), then fall back to URL patterns
const isAzureOpenAi =
Contributor comment (a helper sketch follows at the end of this hunk): The enhanced Azure endpoint detection using azureApiVersion is clear, but the logic to compute isAzureOpenAi is repeated (e.g., here and again in createMessage). Consider extracting this check into a helper function to avoid redundancy.

!!this.options.azureApiVersion ||
urlHost === "azure.com" ||
urlHost.endsWith(".azure.com") ||
options.openAiUseAzure

// Extract base URL for Azure endpoints that might include full paths
let baseURL = this.options.openAiBaseUrl ?? "https://api.openai.com/v1"
if (isAzureOpenAi && this.options.openAiBaseUrl) {
baseURL = this._extractAzureBaseUrl(this.options.openAiBaseUrl)
}

const headers = {
...DEFAULT_HEADERS,
...(this.options.openAiHeaders || {}),
}

if (isAzureAiInference) {
// Azure AI Inference Service (e.g., for DeepSeek) uses a different path structure
this.client = new OpenAI({
baseURL,
apiKey,
defaultHeaders: headers,
defaultQuery: { "api-version": this.options.azureApiVersion || "2024-05-01-preview" },
})
} else if (isAzureOpenAi) {
// Azure API shape slightly differs from the core API shape:
// https://github.com/openai/openai-node?tab=readme-ov-file#microsoft-azure-openai
this.client = new AzureOpenAI({
baseURL,
apiKey,
apiVersion: this.options.azureApiVersion || azureOpenAiDefaultApiVersion,
defaultHeaders: headers,
})
} else {
this.client = new OpenAI({
baseURL,
apiKey,
defaultHeaders: headers,
})
try {
if (isAzureAiInference) {
// Azure AI Inference Service (e.g., for DeepSeek) uses a different path structure
this.client = new OpenAI({
baseURL,
apiKey,
defaultHeaders: headers,
defaultQuery: { "api-version": this.options.azureApiVersion || "2024-05-01-preview" },
})
} else if (isAzureOpenAi) {
// Azure API shape slightly differs from the core API shape:
// https://github.com/openai/openai-node?tab=readme-ov-file#microsoft-azure-openai
this.client = new AzureOpenAI({
baseURL,
apiKey,
apiVersion: this.options.azureApiVersion || azureOpenAiDefaultApiVersion,
defaultHeaders: headers,
})
} else {
this.client = new OpenAI({
baseURL,
apiKey,
defaultHeaders: headers,
})
}
} catch (error) {
const errorMessage = error instanceof Error ? error.message : String(error)
if (isAzureOpenAi) {
throw new Error(
`Failed to initialize Azure OpenAI client: ${errorMessage}\n` +
`Please ensure:\n` +
`1. Your base URL is correct (e.g., https://myresource.openai.azure.com)\n` +
`2. Your API key is valid\n` +
`3. If using a full endpoint URL, try using just the base URL instead`,
)
}
throw new Error(`Failed to initialize OpenAI client: ${errorMessage}`)
}
}
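A minimal sketch of the helper the reviewer suggests above, written as a standalone function for illustration; the name isAzureOpenAiEndpoint and the inlined options shape are assumptions, not part of this PR:

// Hypothetical helper consolidating the Azure OpenAI detection that this PR
// repeats in the constructor and in createMessage: azureApiVersion is the
// primary signal, with URL patterns and the explicit flag as fallbacks.
function isAzureOpenAiEndpoint(options: {
	azureApiVersion?: string
	openAiBaseUrl?: string
	openAiUseAzure?: boolean
}): boolean {
	let urlHost = ""
	try {
		urlHost = new URL(options.openAiBaseUrl ?? "").host
	} catch {
		// Missing or invalid base URL: fall back to the remaining signals.
	}
	return (
		!!options.azureApiVersion ||
		urlHost === "azure.com" ||
		urlHost.endsWith(".azure.com") ||
		!!options.openAiUseAzure
	)
}

In the handler itself this would more likely be a private method reading this.options, so both call sites reduce to a single isAzureOpenAi = this._isAzureOpenAiEndpoint().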

@@ -86,9 +110,33 @@
const deepseekReasoner = modelId.includes("deepseek-reasoner") || enabledR1Format
const ark = modelUrl.includes(".volces.com")

// Check if this is an Azure OpenAI endpoint
const urlHost = this._getUrlHost(this.options.openAiBaseUrl)
const isAzureOpenAi =
!!this.options.azureApiVersion ||
urlHost === "azure.com" ||
urlHost.endsWith(".azure.com") ||
!!this.options.openAiUseAzure

if (modelId.includes("o1") || modelId.includes("o3") || modelId.includes("o4")) {
yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages)
return
try {
yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages, isAzureOpenAi)
return
} catch (error) {
if (isAzureOpenAi && error instanceof Error) {
// Check for common Azure-specific errors
if (
error.message.includes("does not support 'system'") ||
error.message.includes("does not support 'developer'")
) {
throw new Error(
`Azure OpenAI reasoning model error: ${error.message}\n` +
`This has been fixed in the latest version. Please ensure you're using the updated code.`,
)
}
}
throw error
}
}

if (this.options.openAiStreamingEnabled ?? true) {
@@ -287,22 +335,51 @@
modelId: string,
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
isAzureOpenAi: boolean,
): ApiStream {
const modelInfo = this.getModel().info
const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)

if (this.options.openAiStreamingEnabled ?? true) {
const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)

const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
model: modelId,
messages: [
// Azure doesn't support "developer" role, so we need to combine system prompt with first user message
Contributor comment (a helper sketch follows at the end of this hunk): The logic to merge the system prompt with the first user message for Azure endpoints is duplicated in both streaming and non-streaming branches (and also in handleO3FamilyMessage). Consider refactoring this into a utility/helper function to enhance maintainability.

This comment was generated because it violated a code review rule: irule_tTqpIuNs8DV0QFGj.
let openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[]
if (isAzureOpenAi) {
const convertedMessages = convertToOpenAiMessages(messages)
if (convertedMessages.length > 0 && convertedMessages[0].role === "user") {
// Combine system prompt with first user message
openAiMessages = [
{
role: "user",
content: `${systemPrompt}\n\n${convertedMessages[0].content}`,
},
...convertedMessages.slice(1),
]
} else {
// If first message isn't a user message, add system prompt as first user message
openAiMessages = [
{
role: "user",
content: systemPrompt,
},
...convertedMessages,
]
}
} else {
// Non-Azure endpoints support "developer" role
openAiMessages = [
{
role: "developer",
content: `Formatting re-enabled\n${systemPrompt}`,
},
...convertToOpenAiMessages(messages),
],
]
}

const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
model: modelId,
messages: openAiMessages,
stream: true,
...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
reasoning_effort: modelInfo.reasoningEffort,
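Responding to the duplication comment above, a sketch of the utility it asks for, assuming the same convertToOpenAiMessages output used in this file; the function name is hypothetical:

import OpenAI from "openai"

// Hypothetical helper: build the message list for reasoning models. On Azure,
// which rejects the "developer" role, the system prompt is merged into the
// first user message (mirroring this PR's string interpolation); elsewhere the
// "developer" role is used as before.
function buildReasoningMessages(
	systemPrompt: string,
	convertedMessages: OpenAI.Chat.ChatCompletionMessageParam[],
	isAzureOpenAi: boolean,
): OpenAI.Chat.ChatCompletionMessageParam[] {
	if (!isAzureOpenAi) {
		return [{ role: "developer", content: `Formatting re-enabled\n${systemPrompt}` }, ...convertedMessages]
	}
	const [first, ...rest] = convertedMessages
	if (first && first.role === "user") {
		return [{ role: "user", content: `${systemPrompt}\n\n${first.content}` }, ...rest]
	}
	return [{ role: "user", content: systemPrompt }, ...convertedMessages]
}

Both the streaming and non-streaming branches could then pass messages: buildReasoningMessages(systemPrompt, convertToOpenAiMessages(messages), isAzureOpenAi) to their request options.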
@@ -321,15 +398,43 @@

yield* this.handleStreamResponse(stream)
} else {
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: modelId,
messages: [
// Azure doesn't support "developer" role, so we need to combine system prompt with first user message
let openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[]
if (isAzureOpenAi) {
const convertedMessages = convertToOpenAiMessages(messages)
if (convertedMessages.length > 0 && convertedMessages[0].role === "user") {
// Combine system prompt with first user message
openAiMessages = [
{
role: "user",
content: `${systemPrompt}\n\n${convertedMessages[0].content}`,
},
...convertedMessages.slice(1),
]
} else {
// If first message isn't a user message, add system prompt as first user message
openAiMessages = [
{
role: "user",
content: systemPrompt,
},
...convertedMessages,
]
}
} else {
// Non-Azure endpoints support "developer" role
openAiMessages = [
{
role: "developer",
content: `Formatting re-enabled\n${systemPrompt}`,
},
...convertToOpenAiMessages(messages),
],
]
}

const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: modelId,
messages: openAiMessages,
reasoning_effort: modelInfo.reasoningEffort,
temperature: undefined,
}
@@ -374,12 +479,32 @@

private _getUrlHost(baseUrl?: string): string {
try {
return new URL(baseUrl ?? "").host
// Extract base URL without query parameters for proper host detection
const url = new URL(baseUrl ?? "")
return url.host
} catch (error) {
return ""
}
}

/**
* Extracts the base URL from a full Azure endpoint URL
* e.g., "https://myresource.openai.azure.com/openai/deployments/gpt-4/chat/completions?api-version=2024-08-01-preview"
* becomes "https://myresource.openai.azure.com"
*/
private _extractAzureBaseUrl(fullUrl: string): string {
try {
const url = new URL(fullUrl)
// For Azure OpenAI, we want just the origin (protocol + host)
if (url.host.includes("azure.com")) {

Check failure — Code scanning / CodeQL: Incomplete URL substring sanitization (High)
'azure.com' can be anywhere in the URL, and arbitrary hosts may come before or after it.

Copilot Autofix (AI, 5 months ago)

To fix the issue, replace the substring check url.host.includes("azure.com") with a more robust validation mechanism that ensures the host matches an explicit whitelist of allowed domains. This approach prevents bypasses by malicious URLs and ensures that only legitimate Azure domains are accepted.

The fix involves:

  1. Defining a whitelist of allowed Azure domains (e.g., ["azure.com"]).
  2. Using url.host to check if the host matches one of the allowed domains exactly or ends with one of the allowed domains (to account for subdomains like services.ai.azure.com).

Changes are required in the _extractAzureBaseUrl method to implement this validation. (A standalone sketch of this check follows after the method below.)

Suggested changeset 1 — src/api/providers/openai.ts (autofix patch; apply locally with git apply):

diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts
--- a/src/api/providers/openai.ts
+++ b/src/api/providers/openai.ts
@@ -498,3 +498,4 @@
 			// For Azure OpenAI, we want just the origin (protocol + host)
-			if (url.host.includes("azure.com")) {
+			const allowedAzureDomains = ["azure.com"];
+			if (allowedAzureDomains.some(domain => url.host === domain || url.host.endsWith(`.${domain}`))) {
 				return url.origin

Unable to commit as this autofix suggestion is now outdated.
return url.origin
}
return fullUrl
} catch {
return fullUrl
}
}
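As referenced in the CodeQL alert above, a standalone sketch of the stricter host check; the function name and the hostile hostname are illustrative only:

// Suffix-based check, as in the suggested autofix: accept "azure.com" itself
// or any subdomain of it, instead of a substring test.
function isAzureHost(host: string): boolean {
	return host === "azure.com" || host.endsWith(".azure.com")
}

// Why the substring test is unsafe (hypothetical hostnames):
// "azure.com.evil.example".includes("azure.com")  -> true (would be misclassified)
// isAzureHost("azure.com.evil.example")           -> false
// isAzureHost("myresource.openai.azure.com")      -> true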

private _isGrokXAI(baseUrl?: string): boolean {
const urlHost = this._getUrlHost(baseUrl)
return urlHost.includes("x.ai")