Skip to content

Commit a28d50e

Browse files
authored
fix: detect Claude models by name for API protocol selection (#5840)
* fix: detect Claude models by name for API protocol selection - Modified getApiProtocol to accept modelId parameter - Added check for 'claude' in model name (case-insensitive) - Updated Task.ts to pass model ID to getApiProtocol - Added comprehensive tests for the new logic Fixes #5830 * fix: limit Claude model detection to vertex and bedrock providers only - Modified getApiProtocol to only detect Claude models by name when provider is vertex or bedrock - Added comprehensive unit tests for getApiProtocol function as requested in PR review - This ensures Claude models are only auto-detected for providers that need it
1 parent a7771a3 commit a28d50e

File tree

4 files changed

+197
-5
lines changed

4 files changed

+197
-5
lines changed
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import { describe, it, expect } from "vitest"
2+
import { getApiProtocol } from "../provider-settings.js"
3+
4+
describe("getApiProtocol", () => {
5+
describe("Anthropic-style providers", () => {
6+
it("should return 'anthropic' for anthropic provider", () => {
7+
expect(getApiProtocol("anthropic")).toBe("anthropic")
8+
expect(getApiProtocol("anthropic", "gpt-4")).toBe("anthropic")
9+
})
10+
11+
it("should return 'anthropic' for claude-code provider", () => {
12+
expect(getApiProtocol("claude-code")).toBe("anthropic")
13+
expect(getApiProtocol("claude-code", "some-model")).toBe("anthropic")
14+
})
15+
})
16+
17+
describe("Vertex provider with Claude models", () => {
18+
it("should return 'anthropic' for vertex provider with claude models", () => {
19+
expect(getApiProtocol("vertex", "claude-3-opus")).toBe("anthropic")
20+
expect(getApiProtocol("vertex", "Claude-3-Sonnet")).toBe("anthropic")
21+
expect(getApiProtocol("vertex", "CLAUDE-instant")).toBe("anthropic")
22+
expect(getApiProtocol("vertex", "anthropic/claude-3-haiku")).toBe("anthropic")
23+
})
24+
25+
it("should return 'openai' for vertex provider with non-claude models", () => {
26+
expect(getApiProtocol("vertex", "gpt-4")).toBe("openai")
27+
expect(getApiProtocol("vertex", "gemini-pro")).toBe("openai")
28+
expect(getApiProtocol("vertex", "llama-2")).toBe("openai")
29+
})
30+
})
31+
32+
describe("Bedrock provider with Claude models", () => {
33+
it("should return 'anthropic' for bedrock provider with claude models", () => {
34+
expect(getApiProtocol("bedrock", "claude-3-opus")).toBe("anthropic")
35+
expect(getApiProtocol("bedrock", "Claude-3-Sonnet")).toBe("anthropic")
36+
expect(getApiProtocol("bedrock", "CLAUDE-instant")).toBe("anthropic")
37+
expect(getApiProtocol("bedrock", "anthropic.claude-v2")).toBe("anthropic")
38+
})
39+
40+
it("should return 'openai' for bedrock provider with non-claude models", () => {
41+
expect(getApiProtocol("bedrock", "gpt-4")).toBe("openai")
42+
expect(getApiProtocol("bedrock", "titan-text")).toBe("openai")
43+
expect(getApiProtocol("bedrock", "llama-2")).toBe("openai")
44+
})
45+
})
46+
47+
describe("Other providers with Claude models", () => {
48+
it("should return 'openai' for non-vertex/bedrock providers with claude models", () => {
49+
expect(getApiProtocol("openrouter", "claude-3-opus")).toBe("openai")
50+
expect(getApiProtocol("openai", "claude-3-sonnet")).toBe("openai")
51+
expect(getApiProtocol("litellm", "claude-instant")).toBe("openai")
52+
expect(getApiProtocol("ollama", "claude-model")).toBe("openai")
53+
})
54+
})
55+
56+
describe("Edge cases", () => {
57+
it("should return 'openai' when provider is undefined", () => {
58+
expect(getApiProtocol(undefined)).toBe("openai")
59+
expect(getApiProtocol(undefined, "claude-3-opus")).toBe("openai")
60+
})
61+
62+
it("should return 'openai' when model is undefined", () => {
63+
expect(getApiProtocol("openai")).toBe("openai")
64+
expect(getApiProtocol("vertex")).toBe("openai")
65+
expect(getApiProtocol("bedrock")).toBe("openai")
66+
})
67+
68+
it("should handle empty strings", () => {
69+
expect(getApiProtocol("vertex", "")).toBe("openai")
70+
expect(getApiProtocol("bedrock", "")).toBe("openai")
71+
})
72+
73+
it("should be case-insensitive for claude detection", () => {
74+
expect(getApiProtocol("vertex", "CLAUDE-3-OPUS")).toBe("anthropic")
75+
expect(getApiProtocol("bedrock", "claude-3-opus")).toBe("anthropic")
76+
expect(getApiProtocol("vertex", "ClAuDe-InStAnT")).toBe("anthropic")
77+
})
78+
})
79+
})

packages/types/src/provider-settings.ts

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,23 @@ export const getModelId = (settings: ProviderSettings): string | undefined => {
303303
// Providers that use Anthropic-style API protocol
304304
export const ANTHROPIC_STYLE_PROVIDERS: ProviderName[] = ["anthropic", "claude-code"]
305305

306-
// Helper function to determine API protocol for a provider
307-
export const getApiProtocol = (provider: ProviderName | undefined): "anthropic" | "openai" => {
308-
return provider && ANTHROPIC_STYLE_PROVIDERS.includes(provider) ? "anthropic" : "openai"
306+
// Helper function to determine API protocol for a provider and model
307+
export const getApiProtocol = (provider: ProviderName | undefined, modelId?: string): "anthropic" | "openai" => {
308+
// First check if the provider is an Anthropic-style provider
309+
if (provider && ANTHROPIC_STYLE_PROVIDERS.includes(provider)) {
310+
return "anthropic"
311+
}
312+
313+
// For vertex and bedrock providers, check if the model ID contains "claude" (case-insensitive)
314+
if (
315+
provider &&
316+
(provider === "vertex" || provider === "bedrock") &&
317+
modelId &&
318+
modelId.toLowerCase().includes("claude")
319+
) {
320+
return "anthropic"
321+
}
322+
323+
// Default to OpenAI protocol
324+
return "openai"
309325
}

src/core/task/Task.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import {
2323
TelemetryEventName,
2424
TodoItem,
2525
getApiProtocol,
26+
getModelId,
2627
} from "@roo-code/types"
2728
import { TelemetryService } from "@roo-code/telemetry"
2829
import { CloudService } from "@roo-code/cloud"
@@ -1211,8 +1212,9 @@ export class Task extends EventEmitter<ClineEvents> {
12111212
// take a few seconds. For the best UX we show a placeholder api_req_started
12121213
// message with a loading spinner as this happens.
12131214

1214-
// Determine API protocol based on provider
1215-
const apiProtocol = getApiProtocol(this.apiConfiguration.apiProvider)
1215+
// Determine API protocol based on provider and model
1216+
const modelId = getModelId(this.apiConfiguration)
1217+
const apiProtocol = getApiProtocol(this.apiConfiguration.apiProvider, modelId)
12161218

12171219
await this.say(
12181220
"api_req_started",

src/core/task/__tests__/Task.spec.ts

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,5 +1398,100 @@ describe("Cline", () => {
13981398
expect(task.diffStrategy).toBeUndefined()
13991399
})
14001400
})
1401+
1402+
describe("getApiProtocol", () => {
1403+
it("should determine API protocol based on provider and model", async () => {
1404+
// Test with Anthropic provider
1405+
const anthropicConfig = {
1406+
...mockApiConfig,
1407+
apiProvider: "anthropic" as const,
1408+
apiModelId: "gpt-4",
1409+
}
1410+
const anthropicTask = new Task({
1411+
provider: mockProvider,
1412+
apiConfiguration: anthropicConfig,
1413+
task: "test task",
1414+
startTask: false,
1415+
})
1416+
// Should use anthropic protocol even with non-claude model
1417+
expect(anthropicTask.apiConfiguration.apiProvider).toBe("anthropic")
1418+
1419+
// Test with OpenRouter provider and Claude model
1420+
const openrouterClaudeConfig = {
1421+
apiProvider: "openrouter" as const,
1422+
openRouterModelId: "anthropic/claude-3-opus",
1423+
}
1424+
const openrouterClaudeTask = new Task({
1425+
provider: mockProvider,
1426+
apiConfiguration: openrouterClaudeConfig,
1427+
task: "test task",
1428+
startTask: false,
1429+
})
1430+
expect(openrouterClaudeTask.apiConfiguration.apiProvider).toBe("openrouter")
1431+
1432+
// Test with OpenRouter provider and non-Claude model
1433+
const openrouterGptConfig = {
1434+
apiProvider: "openrouter" as const,
1435+
openRouterModelId: "openai/gpt-4",
1436+
}
1437+
const openrouterGptTask = new Task({
1438+
provider: mockProvider,
1439+
apiConfiguration: openrouterGptConfig,
1440+
task: "test task",
1441+
startTask: false,
1442+
})
1443+
expect(openrouterGptTask.apiConfiguration.apiProvider).toBe("openrouter")
1444+
1445+
// Test with various Claude model formats
1446+
const claudeModelFormats = [
1447+
"claude-3-opus",
1448+
"Claude-3-Sonnet",
1449+
"CLAUDE-instant",
1450+
"anthropic/claude-3-haiku",
1451+
"some-provider/claude-model",
1452+
]
1453+
1454+
for (const modelId of claudeModelFormats) {
1455+
const config = {
1456+
apiProvider: "openai" as const,
1457+
openAiModelId: modelId,
1458+
}
1459+
const task = new Task({
1460+
provider: mockProvider,
1461+
apiConfiguration: config,
1462+
task: "test task",
1463+
startTask: false,
1464+
})
1465+
// Verify the model ID contains claude (case-insensitive)
1466+
expect(modelId.toLowerCase()).toContain("claude")
1467+
}
1468+
})
1469+
1470+
it("should handle edge cases for API protocol detection", async () => {
1471+
// Test with undefined provider
1472+
const undefinedProviderConfig = {
1473+
apiModelId: "claude-3-opus",
1474+
}
1475+
const undefinedProviderTask = new Task({
1476+
provider: mockProvider,
1477+
apiConfiguration: undefinedProviderConfig,
1478+
task: "test task",
1479+
startTask: false,
1480+
})
1481+
expect(undefinedProviderTask.apiConfiguration.apiProvider).toBeUndefined()
1482+
1483+
// Test with no model ID
1484+
const noModelConfig = {
1485+
apiProvider: "openai" as const,
1486+
}
1487+
const noModelTask = new Task({
1488+
provider: mockProvider,
1489+
apiConfiguration: noModelConfig,
1490+
task: "test task",
1491+
startTask: false,
1492+
})
1493+
expect(noModelTask.apiConfiguration.apiProvider).toBe("openai")
1494+
})
1495+
})
14011496
})
14021497
})

0 commit comments

Comments
 (0)