Skip to content

Commit 71f0209

Browse files
authored
Merge pull request #1293 from narengogi/feat/unified-count-tokens-endpoint
add a new unified route for count_tokens endpoint
2 parents 26930be + a5b9af0 commit 71f0209

File tree

12 files changed

+173
-21
lines changed

12 files changed

+173
-21
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import { RouterError } from '../errors/RouterError';
2+
import {
3+
constructConfigFromRequestHeaders,
4+
tryTargetsRecursively,
5+
} from './handlerUtils';
6+
import { Context } from 'hono';
7+
8+
/**
9+
* Handles the '/messages/count_tokens' API request by selecting the appropriate provider(s) and making the request to them.
10+
*
11+
* @param {Context} c - The Cloudflare Worker context.
12+
* @returns {Promise<Response>} - The response from the provider.
13+
* @throws Will throw an error if no provider options can be determined or if the request to the provider(s) fails.
14+
* @throws Will throw an 500 error if the handler fails due to some reasons
15+
*/
16+
export async function messagesCountTokensHandler(
17+
c: Context
18+
): Promise<Response> {
19+
try {
20+
let request = await c.req.json();
21+
let requestHeaders = Object.fromEntries(c.req.raw.headers);
22+
const camelCaseConfig = constructConfigFromRequestHeaders(requestHeaders);
23+
const tryTargetsResponse = await tryTargetsRecursively(
24+
c,
25+
camelCaseConfig ?? {},
26+
request,
27+
requestHeaders,
28+
'messagesCountTokens',
29+
'POST',
30+
'config'
31+
);
32+
33+
return tryTargetsResponse;
34+
} catch (err: any) {
35+
console.log('messagesCountTokens error', err.message);
36+
let statusCode = 500;
37+
let errorMessage = 'Something went wrong';
38+
39+
if (err instanceof RouterError) {
40+
statusCode = 400;
41+
errorMessage = err.message;
42+
}
43+
44+
return new Response(
45+
JSON.stringify({
46+
status: 'failure',
47+
message: errorMessage,
48+
}),
49+
{
50+
status: statusCode,
51+
headers: {
52+
'content-type': 'application/json',
53+
},
54+
}
55+
);
56+
}
57+
}

src/index.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import { messagesHandler } from './handlers/messagesHandler';
3636
// Config
3737
import conf from '../conf.json';
3838
import modelResponsesHandler from './handlers/modelResponsesHandler';
39+
import { messagesCountTokensHandler } from './handlers/messagesCountTokensHandler';
3940

4041
// Create a new Hono server instance
4142
const app = new Hono();
@@ -126,6 +127,12 @@ app.onError((err, c) => {
126127
*/
127128
app.post('/v1/messages', requestValidator, messagesHandler);
128129

130+
app.post(
131+
'/v1/messages/count_tokens',
132+
requestValidator,
133+
messagesCountTokensHandler
134+
);
135+
129136
/**
130137
* POST route for '/v1/chat/completions'.
131138
* Handles requests by passing them to the chatCompletionsHandler.

src/providers/anthropic/api.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ const AnthropicAPIConfig: ProviderAPIConfig = {
3131
return '/messages';
3232
case 'messages':
3333
return '/messages';
34+
case 'messagesCountTokens':
35+
return '/messages/count_tokens';
3436
default:
3537
return '';
3638
}

src/providers/anthropic/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ const AnthropicConfig: ProviderConfigs = {
1919
complete: AnthropicCompleteConfig,
2020
chatComplete: AnthropicChatCompleteConfig,
2121
messages: AnthropicMessagesConfig,
22+
messagesCountTokens: AnthropicMessagesConfig,
2223
api: AnthropicAPIConfig,
2324
responseTransforms: {
2425
'stream-complete': AnthropicCompleteStreamChunkTransform,

src/providers/bedrock/api.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,9 @@ const BedrockAPIConfig: BedrockAPIConfigInterface = {
280280
case 'cancelFinetune': {
281281
return `/model-customization-jobs/${jobId}/stop`;
282282
}
283+
case 'messagesCountTokens': {
284+
return `/model/${uriEncodedModel}/count-tokens`;
285+
}
283286
default:
284287
return '';
285288
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import { ProviderConfig } from '../types';
2+
import { BedrockMessagesParams } from './types';
3+
import { transformUsingProviderConfig } from '../../services/transformToProviderRequest';
4+
import { BedrockConverseMessagesConfig } from './messages';
5+
import { Params } from '../../types/requestBody';
6+
import { BEDROCK } from '../../globals';
7+
import { BedrockErrorResponseTransform } from './chatComplete';
8+
import { generateInvalidProviderResponseError } from '../utils';
9+
10+
// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CountTokens.html#API_runtime_CountTokens_RequestSyntax
11+
export const BedrockConverseMessageCountTokensConfig: ProviderConfig = {
12+
messages: {
13+
param: 'input',
14+
required: true,
15+
transform: (params: BedrockMessagesParams) => {
16+
return {
17+
converse: transformUsingProviderConfig(
18+
BedrockConverseMessagesConfig,
19+
params as Params
20+
),
21+
};
22+
},
23+
},
24+
};
25+
26+
// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CountTokens.html#API_runtime_CountTokens_ResponseSyntax
27+
export const BedrockConverseMessageCountTokensResponseTransform = (
28+
response: any,
29+
responseStatus: number
30+
) => {
31+
if (responseStatus !== 200 && 'error' in response) {
32+
return (
33+
BedrockErrorResponseTransform(response) ||
34+
generateInvalidProviderResponseError(response, BEDROCK)
35+
);
36+
}
37+
38+
if ('inputTokens' in response) {
39+
return {
40+
input_tokens: response.inputTokens,
41+
};
42+
}
43+
44+
return generateInvalidProviderResponseError(response, BEDROCK);
45+
};

src/providers/bedrock/index.ts

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ import {
8080
BedrockConverseMessagesStreamChunkTransform,
8181
BedrockMessagesResponseTransform,
8282
} from './messages';
83+
import {
84+
BedrockConverseMessageCountTokensConfig,
85+
BedrockConverseMessageCountTokensResponseTransform,
86+
} from './countTokens';
8387

8488
const BedrockConfig: ProviderConfigs = {
8589
api: BedrockAPIConfig,
@@ -110,8 +114,6 @@ const BedrockConfig: ProviderConfigs = {
110114
responseTransforms: {
111115
'stream-complete': BedrockAnthropicCompleteStreamChunkTransform,
112116
complete: BedrockAnthropicCompleteResponseTransform,
113-
messages: BedrockMessagesResponseTransform,
114-
'stream-messages': BedrockConverseMessagesStreamChunkTransform,
115117
},
116118
};
117119
break;
@@ -201,24 +203,40 @@ const BedrockConfig: ProviderConfigs = {
201203
},
202204
};
203205
}
204-
if (!config.chatComplete) {
205-
config.chatComplete = BedrockConverseChatCompleteConfig;
206-
}
207-
if (!config.messages) {
208-
config.messages = BedrockConverseMessagesConfig;
209-
}
210-
if (!config.responseTransforms?.['stream-chatComplete']) {
211-
config.responseTransforms = {
212-
...(config.responseTransforms ?? {}),
213-
'stream-chatComplete': BedrockChatCompleteStreamChunkTransform,
214-
};
215-
}
216-
if (!config.responseTransforms?.chatComplete) {
217-
config.responseTransforms = {
218-
...(config.responseTransforms ?? {}),
206+
207+
// defaults
208+
config = {
209+
...config,
210+
...(!config.chatComplete && {
211+
chatComplete: BedrockConverseChatCompleteConfig,
212+
}),
213+
...(!config.messages && {
214+
messages: BedrockConverseMessagesConfig,
215+
}),
216+
...(!config.messagesCountTokens && {
217+
messagesCountTokens: BedrockConverseMessageCountTokensConfig,
218+
}),
219+
};
220+
221+
config.responseTransforms = {
222+
...(config.responseTransforms ?? {}),
223+
...(!config.responseTransforms?.chatComplete && {
219224
chatComplete: BedrockChatCompleteResponseTransform,
220-
};
221-
}
225+
}),
226+
...(!config.responseTransforms?.['stream-chatComplete'] && {
227+
'stream-chatComplete': BedrockChatCompleteStreamChunkTransform,
228+
}),
229+
...(!config.responseTransforms?.messages && {
230+
messages: BedrockMessagesResponseTransform,
231+
}),
232+
...(!config.responseTransforms?.['stream-messages'] && {
233+
'stream-messages': BedrockConverseMessagesStreamChunkTransform,
234+
}),
235+
...(!config.responseTransforms?.messagesCountTokens && {
236+
messagesCountTokens:
237+
BedrockConverseMessageCountTokensResponseTransform,
238+
}),
239+
};
222240
}
223241

224242
const commonResponseTransforms = {

src/providers/google-vertex-ai/api.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ export const GoogleApiConfig: ProviderAPIConfig = {
178178
mappedFn === 'stream-messages'
179179
) {
180180
return `${projectRoute}/publishers/${provider}/models/${model}:streamRawPredict`;
181+
} else if (mappedFn === 'messagesCountTokens') {
182+
return `${projectRoute}/publishers/${provider}/models/count-tokens:rawPredict`;
181183
}
182184
}
183185

src/providers/google-vertex-ai/index.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ import {
5151
VertexAnthropicMessagesConfig,
5252
VertexAnthropicMessagesResponseTransform,
5353
} from './messages';
54+
import { VertexAnthropicMessagesCountTokensConfig } from './messagesCountTokens';
5455
import {
5556
GetMistralAIChatCompleteResponseTransform,
5657
GetMistralAIChatCompleteStreamChunkTransform,
@@ -122,6 +123,7 @@ const VertexConfig: ProviderConfigs = {
122123
createBatch: GoogleBatchCreateConfig,
123124
createFinetune: baseConfig.createFinetune,
124125
messages: VertexAnthropicMessagesConfig,
126+
messagesCountTokens: VertexAnthropicMessagesCountTokensConfig,
125127
responseTransforms: {
126128
'stream-chatComplete':
127129
VertexAnthropicChatCompleteStreamChunkTransform,
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import { MessageCreateParamsBase } from '../../types/MessagesRequest';
2+
import { getMessagesConfig } from '../anthropic-base/messages';
3+
4+
export const VertexAnthropicMessagesCountTokensConfig = {
5+
...getMessagesConfig({}),
6+
model: {
7+
param: 'model',
8+
required: true,
9+
transform: (params: MessageCreateParamsBase) => {
10+
let model = params.model ?? '';
11+
return model.replace('anthropic.', '');
12+
},
13+
},
14+
};

0 commit comments

Comments
 (0)