Skip to content

Commit a4590ec

Browse files
committed
count tokens endpoint for bedrock
1 parent 39542fc commit a4590ec

File tree

4 files changed

+86
-20
lines changed

4 files changed

+86
-20
lines changed

src/providers/bedrock/api.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,9 @@ const BedrockAPIConfig: BedrockAPIConfigInterface = {
280280
case 'cancelFinetune': {
281281
return `/model-customization-jobs/${jobId}/stop`;
282282
}
283+
case 'messagesCountTokens': {
284+
return `/model/${uriEncodedModel}/count-tokens`;
285+
}
283286
default:
284287
return '';
285288
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import { ProviderConfig } from '../types';
2+
import { BedrockMessagesParams } from './types';
3+
import { transformUsingProviderConfig } from '../../services/transformToProviderRequest';
4+
import { BedrockConverseMessagesConfig } from './messages';
5+
import { Params } from '../../types/requestBody';
6+
import { BEDROCK } from '../../globals';
7+
import { BedrockErrorResponseTransform } from './chatComplete';
8+
import { generateInvalidProviderResponseError } from '../utils';
9+
10+
// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CountTokens.html#API_runtime_CountTokens_RequestSyntax
11+
export const BedrockConverseMessageCountTokensConfig: ProviderConfig = {
12+
messages: {
13+
param: 'input',
14+
required: true,
15+
transform: (params: BedrockMessagesParams) => {
16+
return {
17+
converse: transformUsingProviderConfig(
18+
BedrockConverseMessagesConfig,
19+
params as Params
20+
),
21+
};
22+
},
23+
},
24+
};
25+
26+
// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CountTokens.html#API_runtime_CountTokens_ResponseSyntax
27+
export const BedrockConverseMessageCountTokensResponseTransform = (
28+
response: any,
29+
responseStatus: number
30+
) => {
31+
if (responseStatus !== 200 && 'error' in response) {
32+
return (
33+
BedrockErrorResponseTransform(response) ||
34+
generateInvalidProviderResponseError(response, BEDROCK)
35+
);
36+
}
37+
38+
if ('inputTokens' in response) {
39+
return {
40+
input_tokens: response.inputTokens,
41+
};
42+
}
43+
44+
return generateInvalidProviderResponseError(response, BEDROCK);
45+
};

src/providers/bedrock/index.ts

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ import {
8080
BedrockConverseMessagesStreamChunkTransform,
8181
BedrockMessagesResponseTransform,
8282
} from './messages';
83+
import {
84+
BedrockConverseMessageCountTokensConfig,
85+
BedrockConverseMessageCountTokensResponseTransform,
86+
} from './countTokens';
8387

8488
const BedrockConfig: ProviderConfigs = {
8589
api: BedrockAPIConfig,
@@ -110,8 +114,6 @@ const BedrockConfig: ProviderConfigs = {
110114
responseTransforms: {
111115
'stream-complete': BedrockAnthropicCompleteStreamChunkTransform,
112116
complete: BedrockAnthropicCompleteResponseTransform,
113-
messages: BedrockMessagesResponseTransform,
114-
'stream-messages': BedrockConverseMessagesStreamChunkTransform,
115117
},
116118
};
117119
break;
@@ -201,24 +203,40 @@ const BedrockConfig: ProviderConfigs = {
201203
},
202204
};
203205
}
204-
if (!config.chatComplete) {
205-
config.chatComplete = BedrockConverseChatCompleteConfig;
206-
}
207-
if (!config.messages) {
208-
config.messages = BedrockConverseMessagesConfig;
209-
}
210-
if (!config.responseTransforms?.['stream-chatComplete']) {
211-
config.responseTransforms = {
212-
...(config.responseTransforms ?? {}),
213-
'stream-chatComplete': BedrockChatCompleteStreamChunkTransform,
214-
};
215-
}
216-
if (!config.responseTransforms?.chatComplete) {
217-
config.responseTransforms = {
218-
...(config.responseTransforms ?? {}),
206+
207+
// defaults
208+
config = {
209+
...config,
210+
...(!config.chatComplete && {
211+
chatComplete: BedrockConverseChatCompleteConfig,
212+
}),
213+
...(!config.messages && {
214+
messages: BedrockConverseMessagesConfig,
215+
}),
216+
...(!config.messagesCountTokens && {
217+
messagesCountTokens: BedrockConverseMessageCountTokensConfig,
218+
}),
219+
};
220+
221+
config.responseTransforms = {
222+
...(config.responseTransforms ?? {}),
223+
...(!config.responseTransforms?.chatComplete && {
219224
chatComplete: BedrockChatCompleteResponseTransform,
220-
};
221-
}
225+
}),
226+
...(!config.responseTransforms?.['stream-chatComplete'] && {
227+
'stream-chatComplete': BedrockChatCompleteStreamChunkTransform,
228+
}),
229+
...(!config.responseTransforms?.messages && {
230+
messages: BedrockMessagesResponseTransform,
231+
}),
232+
...(!config.responseTransforms?.['stream-messages'] && {
233+
'stream-messages': BedrockConverseMessagesStreamChunkTransform,
234+
}),
235+
...(!config.responseTransforms?.messagesCountTokens && {
236+
messagesCountTokens:
237+
BedrockConverseMessageCountTokensResponseTransform,
238+
}),
239+
};
222240
}
223241

224242
const commonResponseTransforms = {

src/services/transformToProviderRequest.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ const getValue = (configParam: string, params: Params, paramConfig: any) => {
7070
export const transformUsingProviderConfig = (
7171
providerConfig: ProviderConfig,
7272
params: Params,
73-
providerOptions: Options
73+
providerOptions?: Options
7474
) => {
7575
const transformedRequest: { [key: string]: any } = {};
7676

0 commit comments

Comments
 (0)