Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 159 additions & 0 deletions src/api/providers/__tests__/anthropic.custom-models.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import { describe, it, expect, vi, beforeEach } from "vitest"
import { Anthropic } from "@anthropic-ai/sdk"

import { AnthropicHandler } from "../anthropic"
import { ANTHROPIC_DEFAULT_MAX_TOKENS } from "@roo-code/types"

// Replace the Anthropic SDK module with a stub: its `Anthropic` export is a mock
// constructor that yields a client exposing only mocked `messages.create` and
// `messages.countTokens`.
vi.mock("@anthropic-ai/sdk", () => {
	const buildStubClient = () => ({
		messages: {
			create: vi.fn(),
			countTokens: vi.fn(),
		},
	})

	return { Anthropic: vi.fn().mockImplementation(buildStubClient) }
})

/**
 * Exercises AnthropicHandler when `apiModelId` is not one of the predefined
 * Anthropic models, both with and without a custom `anthropicBaseUrl`.
 */
describe("AnthropicHandler - Custom Models", () => {
	let handler: AnthropicHandler
	let mockClient: any

	beforeEach(() => {
		// Clear recorded calls so assertions on `Anthropic` only see the current test's construction.
		vi.clearAllMocks()

		// Fresh client per test; `create` and `countTokens` resolve to fixed payloads.
		mockClient = {
			messages: {
				create: vi.fn().mockResolvedValue({
					content: [{ type: "text", text: "Test response" }],
				}),
				countTokens: vi.fn().mockResolvedValue({ input_tokens: 100 }),
			},
		}

		// Every `new Anthropic(...)` in the handler under test now returns `mockClient`.
		;(Anthropic as any).mockImplementation(() => mockClient)
	})

	describe("getModel", () => {
		it("should use predefined model when no custom base URL is set", () => {
			handler = new AnthropicHandler({
				apiKey: "test-key",
				apiModelId: "claude-3-opus-20240229",
			} as any)

			const model = handler.getModel()

			expect(model.id).toBe("claude-3-opus-20240229")
			expect(model.info).toBeDefined()
			expect(model.info.maxTokens).toBe(4096) // Predefined model's max tokens
		})

		it("should fallback to default model when invalid model is provided without custom base URL", () => {
			handler = new AnthropicHandler({
				apiKey: "test-key",
				apiModelId: "custom-model-xyz",
			} as any)

			const model = handler.getModel()

			expect(model.id).toBe("claude-sonnet-4-20250514") // Default model
			expect(model.info).toBeDefined()
		})

		it("should allow custom model when custom base URL is set", () => {
			handler = new AnthropicHandler({
				apiKey: "test-key",
				apiModelId: "glm-4.6-cc-max",
				anthropicBaseUrl: "https://api.z.ai/api/anthropic",
			} as any)

			const model = handler.getModel()

			expect(model.id).toBe("glm-4.6-cc-max")
			expect(model.info).toBeDefined()
			expect(model.info.maxTokens).toBe(ANTHROPIC_DEFAULT_MAX_TOKENS) // Default for custom models
			expect(model.info.contextWindow).toBe(200_000) // Default context window
			expect(model.info.supportsImages).toBe(false) // Conservative default
			expect(model.info.supportsPromptCache).toBe(false) // Conservative default
		})

		it("should still use predefined model info when using known model with custom base URL", () => {
			handler = new AnthropicHandler({
				apiKey: "test-key",
				apiModelId: "claude-3-opus-20240229",
				anthropicBaseUrl: "https://api.z.ai/api/anthropic",
			} as any)

			const model = handler.getModel()

			expect(model.id).toBe("claude-3-opus-20240229")
			expect(model.info.maxTokens).toBe(4096) // Should use predefined model's settings
			expect(model.info.supportsImages).toBe(true) // From predefined model
			expect(model.info.supportsPromptCache).toBe(true) // From predefined model
		})

		it("should handle custom models with special characters", () => {
			handler = new AnthropicHandler({
				apiKey: "test-key",
				apiModelId: "glm-4.5v",
				anthropicBaseUrl: "https://api.z.ai/api/anthropic",
			} as any)

			const model = handler.getModel()

			expect(model.id).toBe("glm-4.5v")
			expect(model.info).toBeDefined()
		})

		it("should use auth token when anthropicUseAuthToken is true", () => {
			// Only construction matters here: the assertion inspects the arguments
			// passed to the SDK constructor, not anything on the handler itself.
			handler = new AnthropicHandler({
				apiKey: "test-token",
				apiModelId: "custom-model",
				anthropicBaseUrl: "https://api.z.ai/api/anthropic",
				anthropicUseAuthToken: true,
			} as any)

			// Verify the Anthropic client was created with authToken instead of apiKey
			expect(Anthropic).toHaveBeenCalledWith({
				baseURL: "https://api.z.ai/api/anthropic",
				authToken: "test-token",
			})
		})
	})

	describe("completePrompt with custom models", () => {
		it("should use custom model ID when making API calls", async () => {
			handler = new AnthropicHandler({
				apiKey: "test-key",
				apiModelId: "glm-4.6-cc-max",
				anthropicBaseUrl: "https://api.z.ai/api/anthropic",
			} as any)

			await handler.completePrompt("Test prompt")

			// `objectContaining` leaves any other request fields the handler sends unconstrained.
			expect(mockClient.messages.create).toHaveBeenCalledWith(
				expect.objectContaining({
					model: "glm-4.6-cc-max",
					messages: [{ role: "user", content: "Test prompt" }],
				}),
			)
		})
	})

	describe("countTokens with custom models", () => {
		it("should use custom model ID for token counting", async () => {
			handler = new AnthropicHandler({
				apiKey: "test-key",
				apiModelId: "glm-4.6-cc-max",
				anthropicBaseUrl: "https://api.z.ai/api/anthropic",
			} as any)

			const content = [{ type: "text" as const, text: "Test content" }]
			await handler.countTokens(content)

			// The content blocks are expected to be forwarded as a single user message.
			expect(mockClient.messages.countTokens).toHaveBeenCalledWith(
				expect.objectContaining({
					model: "glm-4.6-cc-max",
					messages: [{ role: "user", content }],
				}),
			)
		})
	})
})
57 changes: 42 additions & 15 deletions src/api/providers/anthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -245,21 +245,48 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa

getModel() {
const modelId = this.options.apiModelId
let id = modelId && modelId in anthropicModels ? (modelId as AnthropicModelId) : anthropicDefaultModelId
let info: ModelInfo = anthropicModels[id]

// If 1M context beta is enabled for Claude Sonnet 4 or 4.5, update the model info
if ((id === "claude-sonnet-4-20250514" || id === "claude-sonnet-4-5") && this.options.anthropicBeta1MContext) {
// Use the tier pricing for 1M context
const tier = info.tiers?.[0]
if (tier) {
info = {
...info,
contextWindow: tier.contextWindow,
inputPrice: tier.inputPrice,
outputPrice: tier.outputPrice,
cacheWritesPrice: tier.cacheWritesPrice,
cacheReadsPrice: tier.cacheReadsPrice,

// When using a custom base URL, allow any model ID to be used
// This enables compatibility with services like z.ai that provide
// Anthropic-compatible endpoints with custom models
const isUsingCustomBaseUrl = !!this.options.anthropicBaseUrl

let id: string
let info: ModelInfo

if (isUsingCustomBaseUrl && modelId && !(modelId in anthropicModels)) {
// Custom model with custom base URL - use the model ID as-is
// and provide default model info since we don't know the specifics
id = modelId
info = {
maxTokens: ANTHROPIC_DEFAULT_MAX_TOKENS,
contextWindow: 200_000, // Default context window
supportsImages: false, // Conservative default
supportsPromptCache: false, // Conservative default
inputPrice: 0, // Unknown pricing
outputPrice: 0, // Unknown pricing
}
} else {
// Standard Anthropic model or no custom base URL
id = modelId && modelId in anthropicModels ? (modelId as AnthropicModelId) : anthropicDefaultModelId
info = anthropicModels[id as AnthropicModelId]

// If 1M context beta is enabled for Claude Sonnet 4 or 4.5, update the model info
if (
(id === "claude-sonnet-4-20250514" || id === "claude-sonnet-4-5") &&
this.options.anthropicBeta1MContext
) {
// Use the tier pricing for 1M context
const tier = info.tiers?.[0]
if (tier) {
info = {
...info,
contextWindow: tier.contextWindow,
inputPrice: tier.inputPrice,
outputPrice: tier.outputPrice,
cacheWritesPrice: tier.cacheWritesPrice,
cacheReadsPrice: tier.cacheReadsPrice,
}
}
}
}
Expand Down
116 changes: 59 additions & 57 deletions webview-ui/src/components/settings/ApiOptions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -689,66 +689,68 @@ const ApiOptions = ({
<Featherless apiConfiguration={apiConfiguration} setApiConfigurationField={setApiConfigurationField} />
)}

{selectedProviderModels.length > 0 && (
<>
<div>
<label className="block font-medium mb-1">{t("settings:providers.model")}</label>
<Select
value={selectedModelId === "custom-arn" ? "custom-arn" : selectedModelId}
onValueChange={(value) => {
setApiConfigurationField("apiModelId", value)

// Clear custom ARN if not using custom ARN option.
if (value !== "custom-arn" && selectedProvider === "bedrock") {
setApiConfigurationField("awsCustomArn", "")
}

// Clear reasoning effort when switching models to allow the new model's default to take effect
// This is especially important for GPT-5 models which default to "medium"
if (selectedProvider === "openai-native") {
setApiConfigurationField("reasoningEffort", undefined)
}
}}>
<SelectTrigger className="w-full">
<SelectValue placeholder={t("settings:common.select")} />
</SelectTrigger>
<SelectContent>
{selectedProviderModels.map((option) => (
<SelectItem key={option.value} value={option.value}>
{option.label}
</SelectItem>
))}
{selectedProvider === "bedrock" && (
<SelectItem value="custom-arn">{t("settings:labels.useCustomArn")}</SelectItem>
)}
</SelectContent>
</Select>
</div>
{/* Don't show model selector for Anthropic when using custom base URL - it's handled in the Anthropic component */}
{selectedProviderModels.length > 0 &&
!(selectedProvider === "anthropic" && apiConfiguration?.anthropicBaseUrl) && (
<>
<div>
<label className="block font-medium mb-1">{t("settings:providers.model")}</label>
<Select
value={selectedModelId === "custom-arn" ? "custom-arn" : selectedModelId}
onValueChange={(value) => {
setApiConfigurationField("apiModelId", value)

// Clear custom ARN if not using custom ARN option.
if (value !== "custom-arn" && selectedProvider === "bedrock") {
setApiConfigurationField("awsCustomArn", "")
}

// Clear reasoning effort when switching models to allow the new model's default to take effect
// This is especially important for GPT-5 models which default to "medium"
if (selectedProvider === "openai-native") {
setApiConfigurationField("reasoningEffort", undefined)
}
}}>
<SelectTrigger className="w-full">
<SelectValue placeholder={t("settings:common.select")} />
</SelectTrigger>
<SelectContent>
{selectedProviderModels.map((option) => (
<SelectItem key={option.value} value={option.value}>
{option.label}
</SelectItem>
))}
{selectedProvider === "bedrock" && (
<SelectItem value="custom-arn">{t("settings:labels.useCustomArn")}</SelectItem>
)}
</SelectContent>
</Select>
</div>

{/* Show error if a deprecated model is selected */}
{selectedModelInfo?.deprecated && (
<ApiErrorMessage errorMessage={t("settings:validation.modelDeprecated")} />
)}
{/* Show error if a deprecated model is selected */}
{selectedModelInfo?.deprecated && (
<ApiErrorMessage errorMessage={t("settings:validation.modelDeprecated")} />
)}

{selectedProvider === "bedrock" && selectedModelId === "custom-arn" && (
<BedrockCustomArn
apiConfiguration={apiConfiguration}
setApiConfigurationField={setApiConfigurationField}
/>
)}
{selectedProvider === "bedrock" && selectedModelId === "custom-arn" && (
<BedrockCustomArn
apiConfiguration={apiConfiguration}
setApiConfigurationField={setApiConfigurationField}
/>
)}

{/* Only show model info if not deprecated */}
{!selectedModelInfo?.deprecated && (
<ModelInfoView
apiProvider={selectedProvider}
selectedModelId={selectedModelId}
modelInfo={selectedModelInfo}
isDescriptionExpanded={isDescriptionExpanded}
setIsDescriptionExpanded={setIsDescriptionExpanded}
/>
)}
</>
)}
{/* Only show model info if not deprecated */}
{!selectedModelInfo?.deprecated && (
<ModelInfoView
apiProvider={selectedProvider}
selectedModelId={selectedModelId}
modelInfo={selectedModelInfo}
isDescriptionExpanded={isDescriptionExpanded}
setIsDescriptionExpanded={setIsDescriptionExpanded}
/>
)}
</>
)}

<ThinkingBudget
key={`${selectedProvider}-${selectedModelId}`}
Expand Down
Loading