Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions src/api/providers/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
import type { ApiHandlerOptions } from "../../shared/api"

import { XmlMatcher } from "../../utils/xml-matcher"
import { extractApiVersionFromUrl } from "../../utils/azure-url-parser"

import { convertToOpenAiMessages } from "../transform/openai-format"
import { convertToR1Format } from "../transform/r1-format"
Expand All @@ -35,12 +36,24 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
super()
this.options = options

const baseURL = this.options.openAiBaseUrl ?? "https://api.openai.com/v1"
const originalBaseURL = this.options.openAiBaseUrl ?? "https://api.openai.com/v1"
const apiKey = this.options.openAiApiKey ?? "not-provided"
const isAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
const urlHost = this._getUrlHost(this.options.openAiBaseUrl)
const isAzureOpenAi = urlHost === "azure.com" || urlHost.endsWith(".azure.com") || options.openAiUseAzure

// Extract API version from URL if present and no explicit azureApiVersion is set
let effectiveApiVersion = this.options.azureApiVersion
let baseURL = originalBaseURL

// Extract version for both Azure OpenAI and Azure AI Inference
if ((isAzureOpenAi || isAzureAiInference) && !effectiveApiVersion) {
const extractedVersion = extractApiVersionFromUrl(originalBaseURL)
if (extractedVersion) {
effectiveApiVersion = extractedVersion
}
}

const headers = {
...DEFAULT_HEADERS,
...(this.options.openAiHeaders || {}),
Expand All @@ -49,23 +62,23 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
if (isAzureAiInference) {
// Azure AI Inference Service (e.g., for DeepSeek) uses a different path structure
this.client = new OpenAI({
baseURL,
baseURL: originalBaseURL, // Keep original URL for AI Inference
apiKey,
defaultHeaders: headers,
defaultQuery: { "api-version": this.options.azureApiVersion || "2024-05-01-preview" },
defaultQuery: { "api-version": effectiveApiVersion || "2024-05-01-preview" },
})
} else if (isAzureOpenAi) {
// Azure API shape slightly differs from the core API shape:
// https://github.com/openai/openai-node?tab=readme-ov-file#microsoft-azure-openai
this.client = new AzureOpenAI({
baseURL,
baseURL: originalBaseURL, // Use original URL to maintain exact same behavior
apiKey,
apiVersion: this.options.azureApiVersion || azureOpenAiDefaultApiVersion,
apiVersion: effectiveApiVersion || azureOpenAiDefaultApiVersion,
defaultHeaders: headers,
})
} else {
this.client = new OpenAI({
baseURL,
baseURL: originalBaseURL,
apiKey,
defaultHeaders: headers,
})
Expand Down
190 changes: 190 additions & 0 deletions src/utils/__tests__/azure-url-parser.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
import { describe, it, expect } from "vitest"
import {
extractApiVersionFromUrl,
isAzureOpenAiUrl,
removeApiVersionFromUrl,
isValidAzureApiVersion,
} from "../azure-url-parser"

describe("azure-url-parser", () => {
describe("isValidAzureApiVersion", () => {
it("should return true for valid API version format YYYY-MM-DD", () => {
expect(isValidAzureApiVersion("2024-05-01")).toBe(true)
expect(isValidAzureApiVersion("2023-12-31")).toBe(true)
})

it("should return true for valid API version format YYYY-MM-DD-preview", () => {
expect(isValidAzureApiVersion("2024-05-01-preview")).toBe(true)
expect(isValidAzureApiVersion("2024-12-01-preview")).toBe(true)
})

it("should return false for invalid API version formats", () => {
expect(isValidAzureApiVersion("2024-5-1")).toBe(false) // Missing leading zeros
expect(isValidAzureApiVersion("24-05-01")).toBe(false) // Two-digit year
expect(isValidAzureApiVersion("2024/05/01")).toBe(false) // Wrong separator
expect(isValidAzureApiVersion("2024-05-01-alpha")).toBe(false) // Wrong suffix
expect(isValidAzureApiVersion("invalid-version")).toBe(false)
expect(isValidAzureApiVersion("")).toBe(false)
})
})

describe("extractApiVersionFromUrl", () => {
it("should extract API version from Azure OpenAI URL", () => {
const url =
"https://myresource.openai.azure.com/openai/deployments/mymodel/chat/completions?api-version=2024-05-01-preview"
const result = extractApiVersionFromUrl(url)
expect(result).toBe("2024-05-01-preview")
})

it("should extract API version from URL with multiple query parameters", () => {
const url =
"https://myresource.openai.azure.com/openai/deployments/mymodel/chat/completions?foo=bar&api-version=2024-12-01-preview&baz=qux"
const result = extractApiVersionFromUrl(url)
expect(result).toBe("2024-12-01-preview")
})

it("should return null when no api-version parameter exists", () => {
const url = "https://api.openai.com/v1/chat/completions"
const result = extractApiVersionFromUrl(url)
expect(result).toBeNull()
})

it("should return null for invalid URLs", () => {
const invalidUrl = "not-a-valid-url"
const result = extractApiVersionFromUrl(invalidUrl)
expect(result).toBeNull()
})

it("should handle empty api-version parameter", () => {
const url = "https://myresource.openai.azure.com/openai/deployments/mymodel/chat/completions?api-version="
const result = extractApiVersionFromUrl(url)
expect(result).toBe("")
})

it("should handle URL without query parameters", () => {
const url = "https://myresource.openai.azure.com/openai/deployments/mymodel/chat/completions"
const result = extractApiVersionFromUrl(url)
expect(result).toBeNull()
})

it("should handle URL with duplicate api-version parameters", () => {
const url =
"https://myresource.openai.azure.com/openai/deployments/mymodel/chat/completions?api-version=2024-05-01&api-version=2024-12-01"
const result = extractApiVersionFromUrl(url)
// URL.searchParams.get returns the first value
expect(result).toBe("2024-05-01")
})

it("should handle URL with malformed api-version parameter", () => {
const url =
"https://myresource.openai.azure.com/openai/deployments/mymodel/chat/completions?api-version=invalid-format"
const result = extractApiVersionFromUrl(url)
expect(result).toBe("invalid-format") // Still extracts it, validation is separate
})
})

describe("isAzureOpenAiUrl", () => {
it("should return true for Azure OpenAI URLs with .openai.azure.com", () => {
const url = "https://myresource.openai.azure.com/openai/deployments/mymodel/chat/completions"
const result = isAzureOpenAiUrl(url)
expect(result).toBe(true)
})

it("should return true for Azure URLs ending with .azure.com", () => {
const url = "https://myservice.azure.com/api/v1"
const result = isAzureOpenAiUrl(url)
expect(result).toBe(true)
})

it("should return true for URLs with /openai/deployments/ path", () => {
const url = "https://custom-domain.com/openai/deployments/mymodel/chat/completions"
const result = isAzureOpenAiUrl(url)
expect(result).toBe(true)
})

it("should return false for regular OpenAI URLs", () => {
const url = "https://api.openai.com/v1/chat/completions"
const result = isAzureOpenAiUrl(url)
expect(result).toBe(false)
})

it("should return false for other API URLs", () => {
const url = "https://api.anthropic.com/v1/messages"
const result = isAzureOpenAiUrl(url)
expect(result).toBe(false)
})

it("should return false for invalid URLs", () => {
const invalidUrl = "not-a-valid-url"
const result = isAzureOpenAiUrl(invalidUrl)
expect(result).toBe(false)
})

it("should handle case insensitive hostname matching", () => {
const url = "https://MYRESOURCE.OPENAI.AZURE.COM/openai/deployments/mymodel"
const result = isAzureOpenAiUrl(url)
expect(result).toBe(true)
})

it("should return false for malicious URLs trying to include Azure domain", () => {
const maliciousUrl = "https://evil.openai.azure.com.attacker.com/api/v1"
const result = isAzureOpenAiUrl(maliciousUrl)
expect(result).toBe(false)
})

it("should return true for root openai.azure.com domain", () => {
const url = "https://openai.azure.com/api/v1"
const result = isAzureOpenAiUrl(url)
expect(result).toBe(true)
})

it("should return false for Azure AI Inference Service URLs", () => {
const url = "https://myservice.services.ai.azure.com/models/deployments"
const result = isAzureOpenAiUrl(url)
expect(result).toBe(false)
})
})

describe("removeApiVersionFromUrl", () => {
it("should remove api-version parameter from URL", () => {
const url =
"https://myresource.openai.azure.com/openai/deployments/mymodel/chat/completions?api-version=2024-05-01-preview"
const result = removeApiVersionFromUrl(url)
expect(result).toBe("https://myresource.openai.azure.com/openai/deployments/mymodel/chat/completions")
})

it("should remove api-version parameter while preserving other parameters", () => {
const url =
"https://myresource.openai.azure.com/openai/deployments/mymodel/chat/completions?foo=bar&api-version=2024-05-01-preview&baz=qux"
const result = removeApiVersionFromUrl(url)
expect(result).toBe(
"https://myresource.openai.azure.com/openai/deployments/mymodel/chat/completions?foo=bar&baz=qux",
)
})

it("should return original URL when no api-version parameter exists", () => {
const url = "https://api.openai.com/v1/chat/completions?foo=bar"
const result = removeApiVersionFromUrl(url)
expect(result).toBe(url)
})

it("should return original URL for invalid URLs", () => {
const invalidUrl = "not-a-valid-url"
const result = removeApiVersionFromUrl(invalidUrl)
expect(result).toBe(invalidUrl)
})

it("should handle URL with only api-version parameter", () => {
const url =
"https://myresource.openai.azure.com/openai/deployments/mymodel/chat/completions?api-version=2024-05-01-preview"
const result = removeApiVersionFromUrl(url)
expect(result).toBe("https://myresource.openai.azure.com/openai/deployments/mymodel/chat/completions")
})

it("should handle URL without query parameters", () => {
const url = "https://myresource.openai.azure.com/openai/deployments/mymodel/chat/completions"
const result = removeApiVersionFromUrl(url)
expect(result).toBe(url)
})
})
})
82 changes: 82 additions & 0 deletions src/utils/azure-url-parser.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/**
* Utility functions for parsing Azure OpenAI URLs and extracting API versions
*/

/**
* Validates if a string is a valid Azure API version format
* @param version The version string to validate
* @returns True if the version follows Azure API version format (YYYY-MM-DD or YYYY-MM-DD-preview)
*/
export function isValidAzureApiVersion(version: string): boolean {
if (!version) return false

// Azure API versions follow the pattern: YYYY-MM-DD or YYYY-MM-DD-preview
const versionPattern = /^\d{4}-\d{2}-\d{2}(-preview)?$/
return versionPattern.test(version)
}

/**
* Extracts the API version from an Azure OpenAI URL query parameter
* @param url The Azure OpenAI URL that may contain an api-version query parameter
* @returns The extracted API version string, or null if not found
*/
export function extractApiVersionFromUrl(url: string): string | null {
try {
const urlObj = new URL(url)
const apiVersion = urlObj.searchParams.get("api-version")

// Validate the extracted version format
if (apiVersion && !isValidAzureApiVersion(apiVersion)) {
console.warn(`Invalid Azure API version format: ${apiVersion}`)
}

return apiVersion
} catch (error) {
// Invalid URL format
return null
}
}

/**
* Checks if a URL appears to be an Azure OpenAI URL
* @param url The URL to check
* @returns True if the URL appears to be an Azure OpenAI URL
*/
export function isAzureOpenAiUrl(url: string): boolean {
try {
const urlObj = new URL(url)
const host = urlObj.host.toLowerCase()

// Exclude Azure AI Inference Service URLs
if (host.endsWith(".services.ai.azure.com")) {
return false
}

// Check for Azure OpenAI hostname patterns
// Use endsWith to prevent matching malicious URLs like evil.openai.azure.com.attacker.com
return (
host.endsWith(".openai.azure.com") ||
host === "openai.azure.com" ||
host.endsWith(".azure.com") ||
urlObj.pathname.includes("/openai/deployments/")
)
} catch (error) {
return false
}
}

/**
* Removes the api-version query parameter from a URL
* @param url The URL to clean
* @returns The URL without the api-version parameter
*/
export function removeApiVersionFromUrl(url: string): string {
try {
const urlObj = new URL(url)
urlObj.searchParams.delete("api-version")
return urlObj.toString()
} catch (error) {
// Return original URL if parsing fails
return url
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ import {
openAiModelInfoSaneDefaults,
} from "@roo-code/types"

import {
extractApiVersionFromUrl,
isAzureOpenAiUrl,
isValidAzureApiVersion,
} from "../../../../../src/utils/azure-url-parser"

import { ExtensionMessage } from "@roo/ExtensionMessage"

import { useAppTranslation } from "@src/i18n/TranslationContext"
Expand Down Expand Up @@ -41,6 +47,12 @@ export const OpenAICompatible = ({
const [azureApiVersionSelected, setAzureApiVersionSelected] = useState(!!apiConfiguration?.azureApiVersion)
const [openAiLegacyFormatSelected, setOpenAiLegacyFormatSelected] = useState(!!apiConfiguration?.openAiLegacyFormat)

// Check if API version can be extracted from the base URL
const baseUrl = apiConfiguration?.openAiBaseUrl || ""
const extractedApiVersion = extractApiVersionFromUrl(baseUrl)
const isAzureUrl = isAzureOpenAiUrl(baseUrl)
const showApiVersionExtraction = isAzureUrl && extractedApiVersion && !azureApiVersionSelected
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable 'showApiVersionExtraction' is computed based on the extracted API version from the Base URL but is not used anywhere in the UI. Consider using it to display a notice (or auto-fill the Azure API version field) to help users understand that the API version was detected from their Base URL, thereby reducing configuration duplication.


const [openAiModels, setOpenAiModels] = useState<Record<string, ModelInfo> | null>(null)

const [customHeaders, setCustomHeaders] = useState<[string, string][]>(() => {
Expand Down Expand Up @@ -194,12 +206,31 @@ export const OpenAICompatible = ({
}}>
{t("settings:modelInfo.azureApiVersion")}
</Checkbox>
{showApiVersionExtraction && (
<div
className="text-sm text-vscode-descriptionForeground ml-6 mb-2"
dangerouslySetInnerHTML={{
__html: t("settings:modelInfo.azureApiVersionDetected", { version: extractedApiVersion }),
}}
/>
)}
{azureApiVersionSelected && (
<VSCodeTextField
value={apiConfiguration?.azureApiVersion || ""}
onInput={handleInputChange("azureApiVersion")}
placeholder={`Default: ${azureOpenAiDefaultApiVersion}`}
className="w-full mt-1"
style={{
borderColor: (() => {
const value = apiConfiguration?.azureApiVersion
if (!value) {
return "var(--vscode-input-border)"
}
return isValidAzureApiVersion(value)
? "var(--vscode-charts-green)"
: "var(--vscode-errorForeground)"
})(),
}}
/>
)}
</div>
Expand Down
Loading