Skip to content

Commit 34829b3

Browse files
committed
feat: ✨ Add gpt-oss to the bedrock provider
1 parent 1220021 commit 34829b3

File tree

2 files changed

+158
-44
lines changed

2 files changed

+158
-44
lines changed

core/llm/llms/Bedrock.ts

Lines changed: 157 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ import {
22
BedrockRuntimeClient,
33
ContentBlock,
44
ConversationRole,
5+
ConverseCommand,
56
ConverseStreamCommand,
6-
ConverseStreamCommandOutput,
77
ImageFormat,
88
InvokeModelCommand,
99
Message,
10-
ToolConfiguration,
10+
ToolConfiguration
1111
} from "@aws-sdk/client-bedrock-runtime";
1212
import { fromNodeProviderChain } from "@aws-sdk/credential-providers";
1313

@@ -91,47 +91,33 @@ class Bedrock extends BaseLLM {
9191
signal: AbortSignal,
9292
options: CompletionOptions,
9393
): AsyncGenerator<ChatMessage> {
94-
const credentials = await this._getCredentials();
95-
const client = new BedrockRuntimeClient({
96-
region: this.region,
97-
endpoint: this.apiBase,
98-
credentials: {
99-
accessKeyId: credentials.accessKeyId,
100-
secretAccessKey: credentials.secretAccessKey,
101-
sessionToken: credentials.sessionToken || "",
102-
},
103-
});
104-
105-
let config_headers =
106-
this.requestOptions && this.requestOptions.headers
107-
? this.requestOptions.headers
108-
: {};
109-
// AWS SigV4 requires strict canonicalization of headers.
110-
// DO NOT USE "_" in your header name. It will return an error like below.
111-
// "The request signature we calculated does not match the signature you provided."
94+
if (options.stream !== false) {
95+
yield* this._streamChatStreaming(messages, signal, options);
96+
} else {
97+
yield* this._streamChatNonStreaming(messages, signal, options);
98+
}
99+
}
112100

113-
client.middlewareStack.add(
114-
(next) => async (args: any) => {
115-
args.request.headers = {
116-
...args.request.headers,
117-
...config_headers,
118-
};
119-
return next(args);
120-
},
121-
{
122-
step: "build",
123-
},
124-
);
101+
/**
102+
* Handles streaming chat using ConverseStreamCommand
103+
*/
104+
private async *_streamChatStreaming(
105+
messages: ChatMessage[],
106+
signal: AbortSignal,
107+
options: CompletionOptions,
108+
): AsyncGenerator<ChatMessage> {
109+
const client = await this._createBedrockClient();
110+
this._addClientMiddleware(client);
125111

126112
const input = this._generateConverseInput(messages, {
127113
...options,
128114
stream: true,
129115
});
130-
const command = new ConverseStreamCommand(input);
131116

132-
const response = (await client.send(command, {
117+
const command = new ConverseStreamCommand(input);
118+
const response = await client.send(command, {
133119
abortSignal: signal,
134-
})) as ConverseStreamCommandOutput;
120+
});
135121

136122
if (!response?.stream) {
137123
throw new Error("No stream received from Bedrock API");
@@ -158,17 +144,17 @@ class Bedrock extends BaseLLM {
158144
role: "assistant",
159145
content: chunk.contentBlockDelta.delta.text,
160146
};
161-
continue;
147+
continue; // Continue parsing the stream
162148
}
163149

164-
// Handle text content
150+
// Handle reasoning text content
165151
if ((chunk.contentBlockDelta.delta as any).reasoningContent?.text) {
166152
yield {
167153
role: "thinking",
168154
content: (chunk.contentBlockDelta.delta as any).reasoningContent
169155
.text,
170156
};
171-
continue;
157+
continue; // Continue parsing the stream
172158
}
173159

174160
// Handle signature for thinking
@@ -178,7 +164,7 @@ class Bedrock extends BaseLLM {
178164
content: "",
179165
signature: delta.reasoningContent.signature,
180166
};
181-
continue;
167+
continue; // Continue parsing the stream
182168
}
183169

184170
// Handle redacted thinking
@@ -188,7 +174,7 @@ class Bedrock extends BaseLLM {
188174
content: "",
189175
redactedThinking: delta.redactedReasoning.data,
190176
};
191-
continue;
177+
continue; // Continue parsing the stream
192178
}
193179

194180
if (
@@ -201,7 +187,7 @@ class Bedrock extends BaseLLM {
201187
}
202188
this._currentToolResponse.input +=
203189
chunk.contentBlockDelta.delta.toolUse.input;
204-
continue;
190+
continue; // Continue parsing the stream
205191
}
206192
}
207193

@@ -213,7 +199,7 @@ class Bedrock extends BaseLLM {
213199
content: "",
214200
redactedThinking: start.redactedReasoning.data,
215201
};
216-
continue;
202+
continue; // Continue parsing the stream
217203
}
218204

219205
const toolUse = chunk.contentBlockStart.start.toolUse;
@@ -224,7 +210,7 @@ class Bedrock extends BaseLLM {
224210
input: "",
225211
};
226212
}
227-
continue;
213+
continue; // Continue parsing the stream
228214
}
229215

230216
if (chunk.contentBlockStop) {
@@ -245,7 +231,7 @@ class Bedrock extends BaseLLM {
245231
};
246232
this._currentToolResponse = null;
247233
}
248-
continue;
234+
continue; // Continue parsing the stream
249235
}
250236
}
251237
} catch (error: unknown) {
@@ -255,6 +241,133 @@ class Bedrock extends BaseLLM {
255241
}
256242
}
257243

244+
/**
245+
* Handles non-streaming chat using ConverseCommand
246+
*/
247+
private async *_streamChatNonStreaming(
248+
messages: ChatMessage[],
249+
signal: AbortSignal,
250+
options: CompletionOptions,
251+
): AsyncGenerator<ChatMessage> {
252+
const client = await this._createBedrockClient();
253+
this._addClientMiddleware(client);
254+
255+
const input = this._generateConverseInput(messages, {
256+
...options,
257+
stream: false,
258+
});
259+
260+
const command = new ConverseCommand(input);
261+
const response = await client.send(command, {
262+
abortSignal: signal,
263+
});
264+
265+
// Reset cache metrics for new request
266+
this._promptCachingMetrics = {
267+
cacheReadInputTokens: 0,
268+
cacheWriteInputTokens: 0,
269+
};
270+
271+
try {
272+
if (response.output?.message?.content) {
273+
for (const contentBlock of response.output.message.content) {
274+
if (contentBlock.text) {
275+
yield {
276+
role: "assistant",
277+
content: contentBlock.text,
278+
};
279+
}
280+
281+
if ((contentBlock as any).reasoningContent) {
282+
const reasoningContent = (contentBlock as any).reasoningContent;
283+
if (reasoningContent.reasoningText) {
284+
yield {
285+
role: "thinking",
286+
content: reasoningContent.reasoningText.text || "",
287+
signature: reasoningContent.reasoningText.signature,
288+
};
289+
}
290+
if (reasoningContent.redactedContent) {
291+
yield {
292+
role: "thinking",
293+
content: "",
294+
redactedThinking: reasoningContent.redactedContent,
295+
};
296+
}
297+
}
298+
299+
if (contentBlock.toolUse) {
300+
yield {
301+
role: "assistant",
302+
content: "",
303+
toolCalls: [
304+
{
305+
id: contentBlock.toolUse.toolUseId,
306+
type: "function",
307+
function: {
308+
name: contentBlock.toolUse.name,
309+
arguments: JSON.stringify(contentBlock.toolUse.input || {}),
310+
},
311+
},
312+
],
313+
};
314+
}
315+
}
316+
}
317+
318+
// Handle usage metadata if available
319+
if (response.usage) {
320+
console.log(`${JSON.stringify(response.usage)}`);
321+
}
322+
} catch (error: unknown) {
323+
// Clean up state and let the original error bubble up to the retry decorator
324+
this._currentToolResponse = null;
325+
throw error;
326+
}
327+
}
328+
329+
/**
330+
* Creates and configures a Bedrock Runtime Client
331+
*/
332+
private async _createBedrockClient(): Promise<BedrockRuntimeClient> {
333+
const credentials = await this._getCredentials();
334+
return new BedrockRuntimeClient({
335+
region: this.region,
336+
endpoint: this.apiBase,
337+
credentials: {
338+
accessKeyId: credentials.accessKeyId,
339+
secretAccessKey: credentials.secretAccessKey,
340+
sessionToken: credentials.sessionToken || "",
341+
},
342+
});
343+
}
344+
345+
/**
346+
* Adds middleware to the Bedrock client for custom headers
347+
*/
348+
private _addClientMiddleware(client: BedrockRuntimeClient): void {
349+
const config_headers =
350+
this.requestOptions && this.requestOptions.headers
351+
? this.requestOptions.headers
352+
: {};
353+
// AWS SigV4 requires strict canonicalization of headers.
354+
// DO NOT USE "_" in your header name. It will return an error like below.
355+
// "The request signature we calculated does not match the signature you provided."
356+
357+
client.middlewareStack.add(
358+
(next) => async (args: any) => {
359+
args.request.headers = {
360+
...args.request.headers,
361+
...config_headers,
362+
};
363+
return next(args);
364+
},
365+
{
366+
step: "build",
367+
},
368+
);
369+
}
370+
258371
/**
259372
* Generates the input payload for the Bedrock Converse API
260373
* @param messages - Array of chat messages

core/llm/toolSupport.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ export const PROVIDER_TOOL_SUPPORT: Record<string, (model: string) => boolean> =
124124
"nova-pro",
125125
"nova-micro",
126126
"nova-premier",
127+
"gpt-oss",
127128
].some((part) => model.toLowerCase().includes(part))
128129
) {
129130
return true;

0 commit comments

Comments
 (0)