Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 241 additions & 0 deletions src/api/providers/__tests__/bedrock-error-handling.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
// npx vitest run src/api/providers/__tests__/bedrock-error-handling.spec.ts

import { vitest, describe, it, expect, beforeEach } from "vitest"
import { AwsBedrockHandler } from "../bedrock"
import { ApiHandlerOptions } from "../../../shared/api"
import { logger } from "../../../utils/logging"

// Mock the logger
vitest.mock("../../../utils/logging", () => ({
logger: {
debug: vitest.fn(),
info: vitest.fn(),
warn: vitest.fn(),
error: vitest.fn(),
fatal: vitest.fn(),
child: vitest.fn().mockReturnValue({
debug: vitest.fn(),
info: vitest.fn(),
warn: vitest.fn(),
error: vitest.fn(),
fatal: vitest.fn(),
}),
},
}))

// Mock AWS SDK
vitest.mock("@aws-sdk/client-bedrock-runtime", () => {
const mockSend = vitest.fn()
const mockConverseCommand = vitest.fn()

const MockBedrockRuntimeClient = class {
public config: any
public send: any

constructor(config: { region?: string }) {
this.config = config
this.send = mockSend
}
}

return {
BedrockRuntimeClient: MockBedrockRuntimeClient,
ConverseCommand: mockConverseCommand,
ConverseStreamCommand: vitest.fn(),
// Export the mock functions for test access
__mockSend: mockSend,
__mockConverseCommand: mockConverseCommand,
}
})

describe("Bedrock Error Handling", () => {
let handler: AwsBedrockHandler

beforeEach(() => {
const defaultOptions: ApiHandlerOptions = {
apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0",
awsRegion: "us-east-1",
}
handler = new AwsBedrockHandler(defaultOptions)
})

describe("getErrorType", () => {
it("should identify throttling errors by HTTP status code 429", () => {
const error = new Error("Request failed") as any
error.status = 429

const errorType = (handler as any).getErrorType(error)
expect(errorType).toBe("THROTTLING")
})

it("should identify throttling errors by AWS metadata httpStatusCode 429", () => {
const error = new Error("Request failed") as any
error.$metadata = { httpStatusCode: 429 }

const errorType = (handler as any).getErrorType(error)
expect(errorType).toBe("THROTTLING")
})

it("should identify throttling errors by ThrottlingException name", () => {
const error = new Error("Request failed") as any
error.name = "ThrottlingException"

const errorType = (handler as any).getErrorType(error)
expect(errorType).toBe("THROTTLING")
})

it("should identify throttling errors by __type ThrottlingException", () => {
const error = new Error("Request failed") as any
error.__type = "ThrottlingException"

const errorType = (handler as any).getErrorType(error)
expect(errorType).toBe("THROTTLING")
})

it("should identify throttling errors by message pattern 'unable to process your request'", () => {
const error = new Error("Bedrock is unable to process your request")

const errorType = (handler as any).getErrorType(error)
expect(errorType).toBe("THROTTLING")
})

it("should identify throttling errors by message pattern 'too many tokens'", () => {
const error = new Error("Too many tokens in request")

const errorType = (handler as any).getErrorType(error)
expect(errorType).toBe("THROTTLING")
})

it("should identify throttling errors by message pattern 'please wait'", () => {
const error = new Error("Please wait before making another request")

const errorType = (handler as any).getErrorType(error)
expect(errorType).toBe("THROTTLING")
})

it("should identify throttling errors by message pattern 'service is temporarily unavailable'", () => {
const error = new Error("Service is temporarily unavailable")

const errorType = (handler as any).getErrorType(error)
expect(errorType).toBe("THROTTLING")
})

it("should identify traditional throttling patterns", () => {
const throttleError = new Error("Request was throttled")
const rateLimitError = new Error("Rate limit exceeded")
const limitError = new Error("Limit reached")

expect((handler as any).getErrorType(throttleError)).toBe("THROTTLING")
expect((handler as any).getErrorType(rateLimitError)).toBe("THROTTLING")
expect((handler as any).getErrorType(limitError)).toBe("THROTTLING")
})

it("should return GENERIC for non-throttling errors", () => {
const error = new Error("Some other error")

const errorType = (handler as any).getErrorType(error)
expect(errorType).toBe("GENERIC")
})

it("should return GENERIC for non-Error objects", () => {
const error = "string error"

const errorType = (handler as any).getErrorType(error)
expect(errorType).toBe("GENERIC")
})

it("should identify access denied errors", () => {
const error = new Error("Access denied to model")

const errorType = (handler as any).getErrorType(error)
expect(errorType).toBe("ACCESS_DENIED")
})

it("should identify validation errors", () => {
const error = new Error("Input tag validation failed")

const errorType = (handler as any).getErrorType(error)
expect(errorType).toBe("VALIDATION_ERROR")
})
})

describe("handleBedrockError", () => {
it("should format throttling error messages with guidance", () => {
const error = new Error("Bedrock is unable to process your request")

const result = (handler as any).handleBedrockError(error, false)
expect(result).toContain("Request was throttled or rate limited")
expect(result).toContain("Reducing the frequency of requests")
})

it("should return streaming chunks for streaming context", () => {
const error = new Error("Some error")

const result = (handler as any).handleBedrockError(error, true)
expect(Array.isArray(result)).toBe(true)
expect(result[0]).toHaveProperty("type", "text")
expect(result[1]).toHaveProperty("type", "usage")
})

it("should return string for non-streaming context", () => {
const error = new Error("Some error")

const result = (handler as any).handleBedrockError(error, false)
expect(typeof result).toBe("string")
expect(result).toContain("Bedrock completion error:")
})
})

describe("Error handling in createMessage and completePrompt", () => {
it("should re-throw throttling errors in createMessage for retry handling", async () => {
const throttlingError = new Error("Bedrock is unable to process your request")

// Mock the AWS SDK to throw a throttling error
const mockModule = await import("@aws-sdk/client-bedrock-runtime")
;(mockModule as any).__mockSend.mockRejectedValueOnce(throttlingError)

const generator = handler.createMessage("test", [])

// The throttling error should be re-thrown, not handled as a streaming error
await expect(generator.next()).rejects.toThrow("Bedrock is unable to process your request")
})

it("should re-throw throttling errors in completePrompt for retry handling", async () => {
const throttlingError = new Error("Too many tokens") as any
throttlingError.status = 429

// Mock the AWS SDK to throw a throttling error
const mockModule = await import("@aws-sdk/client-bedrock-runtime")
;(mockModule as any).__mockSend.mockRejectedValueOnce(throttlingError)

// The throttling error should be re-thrown, not handled as a completion error
await expect(handler.completePrompt("test")).rejects.toThrow("Too many tokens")
})

it("should handle non-throttling errors normally in createMessage", async () => {
const genericError = new Error("Some other error")

// Mock the AWS SDK to throw a generic error
const mockModule = await import("@aws-sdk/client-bedrock-runtime")
;(mockModule as any).__mockSend.mockRejectedValueOnce(genericError)

const generator = handler.createMessage("test", [])

// Generic errors should be handled as streaming errors, not re-thrown
const result = await generator.next()
expect(result.value).toHaveProperty("type", "text")
expect(result.value.text).toContain("Error:")
})

it("should handle non-throttling errors normally in completePrompt", async () => {
const genericError = new Error("Some other error")

// Mock the AWS SDK to throw a generic error
const mockModule = await import("@aws-sdk/client-bedrock-runtime")
;(mockModule as any).__mockSend.mockRejectedValueOnce(genericError)

// Generic errors should be handled as completion errors
await expect(handler.completePrompt("test")).rejects.toThrow("Bedrock completion error:")
})
})
})
49 changes: 46 additions & 3 deletions src/api/providers/bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,20 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
// Clear timeout on error
clearTimeout(timeoutId)

// Use the extracted error handling method for all errors
// Check if this is a throttling error that should trigger retry mechanism
const errorType = this.getErrorType(error)

if (errorType === "THROTTLING") {
// For throttling errors, we want to re-throw immediately to let the retry mechanism in Task.ts handle it
// This ensures throttling errors during streaming get the same retry treatment as first-chunk errors
if (error instanceof Error) {
throw error
} else {
throw new Error("Throttling error occurred during streaming")
}
}

// Use the extracted error handling method for all other errors
const errorChunks = this.handleBedrockError(error, true) // true for streaming context
// Yield each chunk individually to ensure type compatibility
for (const chunk of errorChunks) {
Expand Down Expand Up @@ -634,7 +647,19 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
}
return ""
} catch (error) {
// Use the extracted error handling method for all errors
// Check if this is a throttling error that should be re-thrown for retry handling
const errorType = this.getErrorType(error)

if (errorType === "THROTTLING") {
// For throttling errors, re-throw the original error to allow retry mechanisms to handle it
if (error instanceof Error) {
throw error
} else {
throw new Error("Throttling error occurred")
}
}

// Use the extracted error handling method for all other errors
const errorResult = this.handleBedrockError(error, false) // false for non-streaming context
// Since we're in a non-streaming context, we know the result is a string
const errorMessage = errorResult as string
Expand Down Expand Up @@ -1035,7 +1060,15 @@ Please verify:
logLevel: "error",
},
THROTTLING: {
patterns: ["throttl", "rate", "limit"],
patterns: [
"throttl",
"rate",
"limit",
"unable to process your request",
"too many tokens",
"please wait",
"service is temporarily unavailable",
],
messageTemplate: `Request was throttled or rate limited. Please try:
1. Reducing the frequency of requests
2. If using a provisioned model, check its throughput settings
Expand Down Expand Up @@ -1119,6 +1152,16 @@ Please check:
return "GENERIC"
}

// Check for HTTP 429 status code (Too Many Requests)
if ((error as any).status === 429 || (error as any).$metadata?.httpStatusCode === 429) {
return "THROTTLING"
}

// Check for AWS Bedrock specific throttling exception names
if ((error as any).name === "ThrottlingException" || (error as any).__type === "ThrottlingException") {
return "THROTTLING"
}

const errorMessage = error.message.toLowerCase()
const errorName = error.name.toLowerCase()

Expand Down
Loading