Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions src/api/providers/__tests__/cost.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import { getCost } from "../cost"

// Suite for getCost: covers the Bedrock path with an invokedModelId, the
// zero-cost fallbacks (no invokedModelId / unknown provider), explicit
// input/output token counts, and cache write/read pricing.
describe("getCost", () => {
	const PROMPT = "test prompt"
	const TOTAL_TOKENS = 1000

	it("should return the correct cost for Bedrock provider with invokedModelId", () => {
		// Default split of 1000 tokens: 25% input (250), 75% output (750).
		// claude-3-5-sonnet: 0.003/1000*250 + 0.015/1000*750 = 0.00075 + 0.01125 = 0.012
		expect(getCost("bedrock", PROMPT, "gpt-3.5-turbo", TOTAL_TOKENS, "claude-3-5-sonnet")).toBeCloseTo(0.012, 5)
	})

	it("should return 0 for Bedrock provider without invokedModelId", () => {
		// GPT models are not supported on Bedrock and the model-based
		// fallback was removed, so the cost must be zero.
		expect(getCost("bedrock", PROMPT, "any-model", TOTAL_TOKENS)).toBe(0)
	})

	it("should return 0 for unknown provider", () => {
		expect(getCost("unknown" as any, PROMPT, "gpt-3.5-turbo", TOTAL_TOKENS)).toBe(0)
	})

	it("should use provided input and output tokens when available", () => {
		// Explicit counts (300 in / 700 out) override the default 1:3 split.
		// claude-3-5-sonnet: 0.003/1000*300 + 0.015/1000*700 = 0.0009 + 0.0105 = 0.0114
		expect(getCost("bedrock", PROMPT, "gpt-3.5-turbo", TOTAL_TOKENS, "claude-3-5-sonnet", 300, 700)).toBeCloseTo(
			0.0114,
			5,
		)
	})

	it("should handle cache write and cache read tokens", () => {
		// claude-3-5-sonnet with 300 in / 700 out / 200 cache-write / 100 cache-read:
		//   input       0.003/1000   * 300 = 0.0009
		//   output      0.015/1000   * 700 = 0.0105
		//   cache write 0.00375/1000 * 200 = 0.00075
		//   cache read  0.0003/1000  * 100 = 0.00003
		//   total                          = 0.01218
		expect(
			getCost("bedrock", PROMPT, "gpt-3.5-turbo", TOTAL_TOKENS, "claude-3-5-sonnet", 300, 700, 200, 100),
		).toBeCloseTo(0.01218, 5)
	})

	it("should handle models without cache pricing", () => {
		// claude-3-opus has no cache pricing, so cache tokens contribute nothing:
		//   input  0.015/1000 * 300 = 0.0045
		//   output 0.075/1000 * 700 = 0.0525
		//   total                   = 0.057
		expect(
			getCost("bedrock", PROMPT, "gpt-3.5-turbo", TOTAL_TOKENS, "claude-3-opus", 300, 700, 200, 100),
		).toBeCloseTo(0.057, 5)
	})
})

// NOTE(review): this suite is titled "getBedrockCost" but exercises the
// Bedrock pricing paths through the public getCost entry point — presumably
// getBedrockCost is an internal helper of cost.ts; confirm and consider
// renaming the suite. All tests use the default 1:3 input/output split:
// 1000 tokens -> 250 input, 750 output.
describe("getBedrockCost", () => {
	const PROMPT = "test prompt"
	const MODEL = "any-model"
	const TOTAL_TOKENS = 1000

	it("should return the correct cost for claude-3-5-sonnet", () => {
		// 0.003/1000*250 + 0.015/1000*750 = 0.00075 + 0.01125 = 0.012
		expect(getCost("bedrock", PROMPT, MODEL, TOTAL_TOKENS, "claude-3-5-sonnet")).toBeCloseTo(0.012, 5)
	})

	// GPT models are not supported on Bedrock, so no GPT tests appear here.

	it("should return 0 for unknown invokedModelId", () => {
		expect(getCost("bedrock", PROMPT, MODEL, TOTAL_TOKENS, "unknown-model")).toBe(0)
	})

	it("should return 0 when invokedModelId is not provided", () => {
		// The fallback to model-based cost calculation was removed.
		expect(getCost("bedrock", PROMPT, MODEL, TOTAL_TOKENS)).toBe(0)
	})

	it("should handle intelligent prompt router ARN format", () => {
		// A full inference-profile ARN from an intelligent prompt router must
		// resolve to the same pricing as the bare claude-3-5-sonnet id.
		const routerArn =
			"arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.anthropic.claude-3-5-sonnet-20240620-v1:0"
		expect(getCost("bedrock", PROMPT, "custom-arn", TOTAL_TOKENS, routerArn)).toBeCloseTo(0.012, 5)
	})

	it("should return the correct cost for Amazon Nova Pro", () => {
		// 0.0008/1000*250 + 0.0032/1000*750 = 0.0002 + 0.0024 = 0.0026
		expect(getCost("bedrock", PROMPT, MODEL, TOTAL_TOKENS, "amazon.nova-pro")).toBeCloseTo(0.0026, 5)
	})

	it("should return the correct cost for Amazon Nova Micro", () => {
		// 0.000035/1000*250 + 0.00014/1000*750 = 0.00000875 + 0.000105 = 0.00011375
		expect(getCost("bedrock", PROMPT, MODEL, TOTAL_TOKENS, "amazon.nova-micro")).toBeCloseTo(0.00011375, 8)
	})

	it("should return the correct cost for Amazon Titan Text Express", () => {
		// 0.0002/1000*250 + 0.0006/1000*750 = 0.00005 + 0.00045 = 0.0005
		expect(getCost("bedrock", PROMPT, MODEL, TOTAL_TOKENS, "amazon.titan-text-express")).toBeCloseTo(0.0005, 5)
	})

	it("should return the correct cost for Amazon Titan Text Lite", () => {
		// 0.00015/1000*250 + 0.0002/1000*750 = 0.0000375 + 0.00015 = 0.0001875
		expect(getCost("bedrock", PROMPT, MODEL, TOTAL_TOKENS, "amazon.titan-text-lite")).toBeCloseTo(0.0001875, 7)
	})

	it("should return the correct cost for Amazon Titan Text Embeddings", () => {
		// Embeddings have no real output tokens, but getCost still applies the
		// default 1:3 split, so only the 250 "input" tokens are billed:
		// 0.0001/1000*250 = 0.000025
		expect(getCost("bedrock", PROMPT, MODEL, TOTAL_TOKENS, "amazon.titan-text-embeddings")).toBeCloseTo(
			0.000025,
			6,
		)
	})

	it("should return the correct cost for Llama 3.2 (11B)", () => {
		// Flat rate: 0.00016/1000*250 + 0.00016/1000*750 = 0.00016
		expect(getCost("bedrock", PROMPT, MODEL, TOTAL_TOKENS, "llama-3.2-11b")).toBeCloseTo(0.00016, 6)
	})

	it("should return the correct cost for Llama 3.2 (90B)", () => {
		// Flat rate: 0.00072/1000*250 + 0.00072/1000*750 = 0.00072
		expect(getCost("bedrock", PROMPT, MODEL, TOTAL_TOKENS, "llama-3.2-90b")).toBeCloseTo(0.00072, 6)
	})

	it("should return the correct cost for Llama 3.3 (70B)", () => {
		// Flat rate: 0.00072/1000*250 + 0.00072/1000*750 = 0.00072
		expect(getCost("bedrock", PROMPT, MODEL, TOTAL_TOKENS, "llama-3.3-70b")).toBeCloseTo(0.00072, 6)
	})
})
104 changes: 102 additions & 2 deletions src/api/providers/bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ export interface StreamEvent {
latencyMs: number
}
}
trace?: {
promptRouter?: {
invokedModelId?: string
}
}
}

export class AwsBedrockHandler extends BaseProvider implements SingleCompletionHandler {
Expand Down Expand Up @@ -252,10 +257,49 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH

// Handle metadata events first
if (streamEvent.metadata?.usage) {
// Check if this is a response from an intelligent prompt router
const invokedModelId = streamEvent.trace?.promptRouter?.invokedModelId

// If invokedModelId is present, extract it from the ARN format
let modelIdForCost: string | undefined
if (invokedModelId) {
// Extract the model name from the ARN
// Example ARN: arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.anthropic.claude-3-5-sonnet-20240620-v1:0
const modelMatch = invokedModelId.match(/\/([^\/]+)(?::|$)/)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ARN model extraction logic is duplicated here and in cost.ts. Consider extracting this into a shared utility function to avoid code duplication and ensure consistency.

if (modelMatch && modelMatch[1]) {
const modelName = modelMatch[1]

// Map the model name to the format expected by the cost calculation function
if (modelName.includes("claude-3-5-sonnet")) {
modelIdForCost = "claude-3-5-sonnet"
} else if (modelName.includes("claude-3-sonnet")) {
modelIdForCost = "claude-3-sonnet"
} else if (modelName.includes("claude-3-opus")) {
modelIdForCost = "claude-3-opus"
} else if (modelName.includes("claude-3-haiku")) {
modelIdForCost = "claude-3-haiku"
} else if (modelName.includes("claude-3-5-haiku")) {
modelIdForCost = "claude-3-5-haiku"
} else if (modelName.includes("claude-3-7-sonnet")) {
modelIdForCost = "claude-3-7-sonnet"
}

logger.debug("Extracted model ID from intelligent prompt router", {
ctx: "bedrock",
originalArn: invokedModelId,
extractedModelId: modelIdForCost,
})
}
}

const inputTokens = streamEvent.metadata.usage.inputTokens || 0
const outputTokens = streamEvent.metadata.usage.outputTokens || 0

yield {
type: "usage",
inputTokens: streamEvent.metadata.usage.inputTokens || 0,
outputTokens: streamEvent.metadata.usage.outputTokens || 0,
inputTokens: inputTokens,
outputTokens: outputTokens,
invokedModelId: modelIdForCost,
}
continue
}
Expand Down Expand Up @@ -491,6 +535,22 @@ Please check:
supportsPromptCache: false,
supportsImages: true,
}
} else if (arnLower.includes("llama3.3") || arnLower.includes("llama-3.3")) {
// Llama 3.3 models
modelInfo = {
maxTokens: 8192,
contextWindow: 128_000,
supportsPromptCache: true,
supportsImages: true,
}
} else if (arnLower.includes("llama3.2") || arnLower.includes("llama-3.2")) {
// Llama 3.2 models
modelInfo = {
maxTokens: 8192,
contextWindow: 128_000,
supportsPromptCache: true,
supportsImages: arnLower.includes("90b") || arnLower.includes("11b"),
}
} else if (arnLower.includes("llama3") || arnLower.includes("llama-3")) {
// Llama 3 models typically have 8192 tokens in Bedrock
modelInfo = {
Expand All @@ -499,6 +559,46 @@ Please check:
supportsPromptCache: false,
supportsImages: arnLower.includes("90b") || arnLower.includes("11b"),
}
} else if (arnLower.includes("titan-text-lite")) {
// Amazon Titan Text Lite
modelInfo = {
maxTokens: 4096,
contextWindow: 8_000,
supportsPromptCache: false,
supportsImages: false,
}
} else if (arnLower.includes("titan-text-express")) {
// Amazon Titan Text Express
modelInfo = {
maxTokens: 4096,
contextWindow: 8_000,
supportsPromptCache: false,
supportsImages: false,
}
} else if (arnLower.includes("titan-text-embeddings")) {
// Amazon Titan Text Embeddings
modelInfo = {
maxTokens: 8192,
contextWindow: 8_000,
supportsPromptCache: false,
supportsImages: false,
}
} else if (arnLower.includes("nova-micro")) {
// Amazon Nova Micro
modelInfo = {
maxTokens: 4096,
contextWindow: 128_000,
supportsPromptCache: false,
supportsImages: false,
}
} else if (arnLower.includes("nova-lite")) {
// Amazon Nova Lite
modelInfo = {
maxTokens: 4096,
contextWindow: 128_000,
supportsPromptCache: false,
supportsImages: false,
}
} else if (arnLower.includes("nova-pro")) {
// Amazon Nova Pro
modelInfo = {
Expand Down
Loading