Skip to content

Commit 1be51a4

Browse files
authored
Merge pull request #60 from Premshay/main
feat(api): unify Bedrock provider using Runtime API
2 parents 49067c0 + e93f590 commit 1be51a4

File tree

7 files changed

+1022
-106
lines changed

7 files changed

+1022
-106
lines changed

package-lock.json

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@
192192
},
193193
"dependencies": {
194194
"@anthropic-ai/bedrock-sdk": "^0.10.2",
195+
"@aws-sdk/client-bedrock-runtime": "^3.706.0",
195196
"@anthropic-ai/sdk": "^0.26.0",
196197
"@anthropic-ai/vertex-sdk": "^0.4.1",
197198
"@google/generative-ai": "^0.18.0",
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
import { AwsBedrockHandler } from '../bedrock'
2+
import { ApiHandlerOptions, ModelInfo } from '../../../shared/api'
3+
import { Anthropic } from '@anthropic-ai/sdk'
4+
import { StreamEvent } from '../bedrock'
5+
6+
// Simplified mock for BedrockRuntimeClient
7+
class MockBedrockRuntimeClient {
8+
private _region: string
9+
private mockStream: StreamEvent[] = []
10+
11+
constructor(config: { region: string }) {
12+
this._region = config.region
13+
}
14+
15+
async send(command: any): Promise<{ stream: AsyncIterableIterator<StreamEvent> }> {
16+
return {
17+
stream: this.createMockStream()
18+
}
19+
}
20+
21+
private createMockStream(): AsyncIterableIterator<StreamEvent> {
22+
const self = this;
23+
return {
24+
async *[Symbol.asyncIterator]() {
25+
for (const event of self.mockStream) {
26+
yield event;
27+
}
28+
},
29+
next: async () => {
30+
const value = this.mockStream.shift();
31+
return value ? { value, done: false } : { value: undefined, done: true };
32+
},
33+
return: async () => ({ value: undefined, done: true }),
34+
throw: async (e) => { throw e; }
35+
};
36+
}
37+
38+
setMockStream(stream: StreamEvent[]) {
39+
this.mockStream = stream;
40+
}
41+
42+
get config() {
43+
return { region: this._region };
44+
}
45+
}
46+
47+
describe('AwsBedrockHandler', () => {
48+
const mockOptions: ApiHandlerOptions = {
49+
awsRegion: 'us-east-1',
50+
awsAccessKey: 'mock-access-key',
51+
awsSecretKey: 'mock-secret-key',
52+
apiModelId: 'anthropic.claude-v2',
53+
}
54+
55+
// Override the BedrockRuntimeClient creation in the constructor
56+
class TestAwsBedrockHandler extends AwsBedrockHandler {
57+
constructor(options: ApiHandlerOptions, mockClient?: MockBedrockRuntimeClient) {
58+
super(options)
59+
if (mockClient) {
60+
// Force type casting to bypass strict type checking
61+
(this as any)['client'] = mockClient
62+
}
63+
}
64+
}
65+
66+
test('constructor initializes with correct AWS credentials', () => {
67+
const mockClient = new MockBedrockRuntimeClient({
68+
region: 'us-east-1'
69+
})
70+
71+
const handler = new TestAwsBedrockHandler(mockOptions, mockClient)
72+
73+
// Verify that the client is created with the correct configuration
74+
expect(handler['client']).toBeDefined()
75+
expect(handler['client'].config.region).toBe('us-east-1')
76+
})
77+
78+
test('getModel returns correct model info', () => {
79+
const mockClient = new MockBedrockRuntimeClient({
80+
region: 'us-east-1'
81+
})
82+
83+
const handler = new TestAwsBedrockHandler(mockOptions, mockClient)
84+
const result = handler.getModel()
85+
86+
expect(result).toEqual({
87+
id: 'anthropic.claude-v2',
88+
info: {
89+
maxTokens: 5000,
90+
contextWindow: 128_000,
91+
supportsPromptCache: false
92+
}
93+
})
94+
})
95+
96+
test('createMessage handles successful stream events', async () => {
97+
const mockClient = new MockBedrockRuntimeClient({
98+
region: 'us-east-1'
99+
})
100+
101+
// Mock stream events
102+
const mockStreamEvents: StreamEvent[] = [
103+
{
104+
metadata: {
105+
usage: {
106+
inputTokens: 50,
107+
outputTokens: 100
108+
}
109+
}
110+
},
111+
{
112+
contentBlockStart: {
113+
start: {
114+
text: 'Hello'
115+
}
116+
}
117+
},
118+
{
119+
contentBlockDelta: {
120+
delta: {
121+
text: ' world'
122+
}
123+
}
124+
},
125+
{
126+
messageStop: {
127+
stopReason: 'end_turn'
128+
}
129+
}
130+
]
131+
132+
mockClient.setMockStream(mockStreamEvents)
133+
134+
const handler = new TestAwsBedrockHandler(mockOptions, mockClient)
135+
136+
const systemPrompt = 'You are a helpful assistant'
137+
const messages: Anthropic.Messages.MessageParam[] = [
138+
{ role: 'user', content: 'Say hello' }
139+
]
140+
141+
const generator = handler.createMessage(systemPrompt, messages)
142+
const chunks = []
143+
144+
for await (const chunk of generator) {
145+
chunks.push(chunk)
146+
}
147+
148+
// Verify the chunks match expected stream events
149+
expect(chunks).toHaveLength(3)
150+
expect(chunks[0]).toEqual({
151+
type: 'usage',
152+
inputTokens: 50,
153+
outputTokens: 100
154+
})
155+
expect(chunks[1]).toEqual({
156+
type: 'text',
157+
text: 'Hello'
158+
})
159+
expect(chunks[2]).toEqual({
160+
type: 'text',
161+
text: ' world'
162+
})
163+
})
164+
165+
test('createMessage handles error scenarios', async () => {
166+
const mockClient = new MockBedrockRuntimeClient({
167+
region: 'us-east-1'
168+
})
169+
170+
// Simulate an error by overriding the send method
171+
mockClient.send = () => {
172+
throw new Error('API request failed')
173+
}
174+
175+
const handler = new TestAwsBedrockHandler(mockOptions, mockClient)
176+
177+
const systemPrompt = 'You are a helpful assistant'
178+
const messages: Anthropic.Messages.MessageParam[] = [
179+
{ role: 'user', content: 'Cause an error' }
180+
]
181+
182+
await expect(async () => {
183+
const generator = handler.createMessage(systemPrompt, messages)
184+
const chunks = []
185+
186+
for await (const chunk of generator) {
187+
chunks.push(chunk)
188+
}
189+
}).rejects.toThrow('API request failed')
190+
})
191+
})

0 commit comments

Comments
 (0)