diff --git a/packages/types/src/__tests__/provider-settings.test.ts b/packages/types/src/__tests__/provider-settings.test.ts index 87c5bbcc1c8..8277320289b 100644 --- a/packages/types/src/__tests__/provider-settings.test.ts +++ b/packages/types/src/__tests__/provider-settings.test.ts @@ -12,6 +12,12 @@ describe("getApiProtocol", () => { expect(getApiProtocol("claude-code")).toBe("anthropic") expect(getApiProtocol("claude-code", "some-model")).toBe("anthropic") }) + + it("should return 'anthropic' for bedrock provider", () => { + expect(getApiProtocol("bedrock")).toBe("anthropic") + expect(getApiProtocol("bedrock", "gpt-4")).toBe("anthropic") + expect(getApiProtocol("bedrock", "claude-3-opus")).toBe("anthropic") + }) }) describe("Vertex provider with Claude models", () => { @@ -27,25 +33,14 @@ describe("getApiProtocol", () => { expect(getApiProtocol("vertex", "gemini-pro")).toBe("openai") expect(getApiProtocol("vertex", "llama-2")).toBe("openai") }) - }) - - describe("Bedrock provider with Claude models", () => { - it("should return 'anthropic' for bedrock provider with claude models", () => { - expect(getApiProtocol("bedrock", "claude-3-opus")).toBe("anthropic") - expect(getApiProtocol("bedrock", "Claude-3-Sonnet")).toBe("anthropic") - expect(getApiProtocol("bedrock", "CLAUDE-instant")).toBe("anthropic") - expect(getApiProtocol("bedrock", "anthropic.claude-v2")).toBe("anthropic") - }) - it("should return 'openai' for bedrock provider with non-claude models", () => { - expect(getApiProtocol("bedrock", "gpt-4")).toBe("openai") - expect(getApiProtocol("bedrock", "titan-text")).toBe("openai") - expect(getApiProtocol("bedrock", "llama-2")).toBe("openai") + it("should return 'openai' for vertex provider without model", () => { + expect(getApiProtocol("vertex")).toBe("openai") }) }) - describe("Other providers with Claude models", () => { - it("should return 'openai' for non-vertex/bedrock providers with claude models", () => { + describe("Other providers", () => { + it("should return 'openai' for non-anthropic providers regardless of model", () => { expect(getApiProtocol("openrouter", "claude-3-opus")).toBe("openai") expect(getApiProtocol("openai", "claude-3-sonnet")).toBe("openai") expect(getApiProtocol("litellm", "claude-instant")).toBe("openai") @@ -59,20 +54,13 @@ describe("getApiProtocol", () => { expect(getApiProtocol(undefined, "claude-3-opus")).toBe("openai") }) - it("should return 'openai' when model is undefined", () => { - expect(getApiProtocol("openai")).toBe("openai") - expect(getApiProtocol("vertex")).toBe("openai") - expect(getApiProtocol("bedrock")).toBe("openai") - }) - it("should handle empty strings", () => { expect(getApiProtocol("vertex", "")).toBe("openai") - expect(getApiProtocol("bedrock", "")).toBe("openai") }) it("should be case-insensitive for claude detection", () => { expect(getApiProtocol("vertex", "CLAUDE-3-OPUS")).toBe("anthropic") - expect(getApiProtocol("bedrock", "claude-3-opus")).toBe("anthropic") + expect(getApiProtocol("vertex", "claude-3-opus")).toBe("anthropic") expect(getApiProtocol("vertex", "ClAuDe-InStAnT")).toBe("anthropic") }) }) diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts index be74ae6bb4c..511e803cbd6 100644 --- a/packages/types/src/provider-settings.ts +++ b/packages/types/src/provider-settings.ts @@ -301,7 +301,7 @@ export const getModelId = (settings: ProviderSettings): string | undefined => { } // Providers that use Anthropic-style API protocol -export const ANTHROPIC_STYLE_PROVIDERS: ProviderName[] = ["anthropic", "claude-code"] +export const ANTHROPIC_STYLE_PROVIDERS: ProviderName[] = ["anthropic", "claude-code", "bedrock"] // Helper function to determine API protocol for a provider and model export const getApiProtocol = (provider: ProviderName | undefined, modelId?: string): "anthropic" | "openai" => { @@ -310,13 +310,8 @@ export const getApiProtocol = (provider: ProviderName | undefined, modelId?: str return "anthropic" } - // For vertex and bedrock providers, check if the model ID contains "claude" (case-insensitive) - if ( - provider && - (provider === "vertex" || provider === "bedrock") && - modelId && - modelId.toLowerCase().includes("claude") - ) { + // For vertex provider, check if the model ID contains "claude" (case-insensitive) + if (provider && provider === "vertex" && modelId && modelId.toLowerCase().includes("claude")) { return "anthropic" }