Skip to content

Commit 4f5380c

Browse files
committed
cleanup bedrock provider
add wrapper for environment variables with support for fetching from file path; bedrock cleanup; add conditional import
1 parent 6150f06 commit 4f5380c

File tree

15 files changed

+229
-55
lines changed

15 files changed

+229
-55
lines changed

src/handlers/handlerUtils.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,6 @@ export function constructConfigFromRequestHeaders(
856856
azureApiVersion: requestHeaders[`x-${POWERED_BY}-azure-api-version`],
857857
azureEndpointName: requestHeaders[`x-${POWERED_BY}-azure-endpoint-name`],
858858
azureFoundryUrl: requestHeaders[`x-${POWERED_BY}-azure-foundry-url`],
859-
azureExtraParams: requestHeaders[`x-${POWERED_BY}-azure-extra-params`],
860859
azureAdToken: requestHeaders[`x-${POWERED_BY}-azure-ad-token`],
861860
azureAuthMode: requestHeaders[`x-${POWERED_BY}-azure-auth-mode`],
862861
azureManagedClientId:
@@ -865,6 +864,7 @@ export function constructConfigFromRequestHeaders(
865864
azureEntraClientSecret:
866865
requestHeaders[`x-${POWERED_BY}-azure-entra-client-secret`],
867866
azureEntraTenantId: requestHeaders[`x-${POWERED_BY}-azure-entra-tenant-id`],
867+
azureExtraParameters: requestHeaders[`x-${POWERED_BY}-azure-extra-params`],
868868
};
869869

870870
const awsConfig = {

src/providers/azure-ai-inference/api.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,14 @@ const AzureAIInferenceAPI: ProviderAPIConfig = {
3939
headers: async ({ providerOptions, fn }) => {
4040
const {
4141
apiKey,
42-
azureExtraParams,
42+
azureExtraParameters,
4343
azureDeploymentName,
4444
azureAdToken,
4545
azureAuthMode,
4646
} = providerOptions;
4747

4848
const headers: Record<string, string> = {
49-
'extra-parameters': azureExtraParams ?? 'drop',
49+
'extra-parameters': azureExtraParameters ?? 'drop',
5050
...(azureDeploymentName && {
5151
'azureml-model-deployment': azureDeploymentName,
5252
}),

src/providers/bedrock/api.ts

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import { Context } from 'hono';
2-
import { Options } from '../../types/requestBody';
2+
import { Options, Params } from '../../types/requestBody';
33
import { endpointStrings, ProviderAPIConfig } from '../types';
44
import { bedrockInvokeModels } from './constants';
55
import {
6+
getAwsEndpointDomain,
67
generateAWSHeaders,
7-
getAssumedRoleCredentials,
88
getFoundationModelFromInferenceProfile,
99
providerAssumedRoleCredentials,
1010
} from './utils';
@@ -18,6 +18,7 @@ interface BedrockAPIConfigInterface extends Omit<ProviderAPIConfig, 'headers'> {
1818
transformedRequestBody: Record<string, any> | string;
1919
transformedRequestUrl: string;
2020
gatewayRequestBody?: Params;
21+
headers?: Record<string, string>;
2122
}) => Promise<Record<string, any>> | Record<string, any>;
2223
}
2324

@@ -66,7 +67,14 @@ const ENDPOINTS_TO_ROUTE_TO_S3 = [
6667
'initiateMultipartUpload',
6768
];
6869

69-
const getMethod = (fn: endpointStrings, transformedRequestUrl: string) => {
70+
const getMethod = (
71+
fn: endpointStrings,
72+
transformedRequestUrl: string,
73+
c: Context
74+
) => {
75+
if (fn === 'proxy') {
76+
return c.req.method;
77+
}
7078
if (fn === 'uploadFile') {
7179
const url = new URL(transformedRequestUrl);
7280
return url.searchParams.get('partNumber') ? 'PUT' : 'POST';
@@ -121,36 +129,47 @@ const BedrockAPIConfig: BedrockAPIConfigInterface = {
121129
gatewayRequestURL.split('/v1/files/')[1]
122130
);
123131
const bucketName = s3URL.replace('s3://', '').split('/')[0];
124-
return `https://${bucketName}.s3.${providerOptions.awsRegion || 'us-east-1'}.amazonaws.com`;
132+
return `https://${bucketName}.s3.${providerOptions.awsRegion || 'us-east-1'}.${getAwsEndpointDomain(c)}`;
125133
}
126134
if (fn === 'retrieveFileContent') {
127135
const s3URL = decodeURIComponent(
128136
gatewayRequestURL.split('/v1/files/')[1]
129137
);
130138
const bucketName = s3URL.replace('s3://', '').split('/')[0];
131-
return `https://${bucketName}.s3.${providerOptions.awsRegion || 'us-east-1'}.amazonaws.com`;
139+
return `https://${bucketName}.s3.${providerOptions.awsRegion || 'us-east-1'}.${getAwsEndpointDomain(c)}`;
132140
}
133141
if (fn === 'uploadFile')
134-
return `https://${providerOptions.awsS3Bucket}.s3.${providerOptions.awsRegion || 'us-east-1'}.amazonaws.com`;
142+
return `https://${providerOptions.awsS3Bucket}.s3.${providerOptions.awsRegion || 'us-east-1'}.${getAwsEndpointDomain(c)}`;
135143
const isAWSControlPlaneEndpoint =
136144
fn && AWS_CONTROL_PLANE_ENDPOINTS.includes(fn);
137-
return `https://${isAWSControlPlaneEndpoint ? 'bedrock' : 'bedrock-runtime'}.${providerOptions.awsRegion || 'us-east-1'}.amazonaws.com`;
145+
return `https://${isAWSControlPlaneEndpoint ? 'bedrock' : 'bedrock-runtime'}.${providerOptions.awsRegion || 'us-east-1'}.${getAwsEndpointDomain(c)}`;
138146
},
139147
headers: async ({
140148
c,
141149
fn,
142150
providerOptions,
143151
transformedRequestBody,
144152
transformedRequestUrl,
153+
gatewayRequestBody, // for proxy use the passed body blindly
154+
headers: requestHeaders,
145155
}) => {
146-
const method = getMethod(fn as endpointStrings, transformedRequestUrl);
147-
const service = getService(fn as endpointStrings);
156+
const { awsService } = providerOptions;
157+
const method =
158+
c.get('method') || // method set specifically into context
159+
getMethod(fn as endpointStrings, transformedRequestUrl, c); // method calculated
160+
const service = awsService || getService(fn as endpointStrings);
148161

149-
const headers: Record<string, string> = {
150-
'content-type': 'application/json',
151-
};
162+
let headers: Record<string, string> = {};
152163

153-
if (method === 'PUT' || method === 'GET') {
164+
if (fn === 'proxy' && service !== 'bedrock') {
165+
headers = { ...(requestHeaders ?? {}) };
166+
} else {
167+
headers = {
168+
'content-type': 'application/json',
169+
};
170+
}
171+
172+
if ((method === 'PUT' || method === 'GET') && fn !== 'proxy') {
154173
delete headers['content-type'];
155174
}
156175

@@ -160,7 +179,8 @@ const BedrockAPIConfig: BedrockAPIConfigInterface = {
160179
await providerAssumedRoleCredentials(c, providerOptions);
161180
}
162181

163-
let finalRequestBody = transformedRequestBody;
182+
let finalRequestBody =
183+
fn === 'proxy' ? gatewayRequestBody : transformedRequestBody;
164184

165185
if (['cancelFinetune', 'cancelBatch'].includes(fn as endpointStrings)) {
166186
// Cancel doesn't require any body, but fetch is sending empty body, to match the signature this block is required.
@@ -183,7 +203,6 @@ const BedrockAPIConfig: BedrockAPIConfigInterface = {
183203
fn,
184204
gatewayRequestBodyJSON: gatewayRequestBody,
185205
gatewayRequestURL,
186-
c,
187206
}) => {
188207
if (fn === 'retrieveFile') {
189208
const fileId = decodeURIComponent(

src/providers/bedrock/chatComplete.ts

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import {
22
BEDROCK,
33
documentMimeTypes,
4-
fileExtensionMimeTypeMap,
54
imagesMimeTypes,
5+
fileExtensionMimeTypeMap,
66
} from '../../globals';
77
import {
88
Message,
@@ -62,7 +62,7 @@ export interface BedrockChatCompletionsParams extends Params {
6262
}
6363

6464
export interface BedrockConverseAnthropicChatCompletionsParams
65-
extends Omit<BedrockChatCompletionsParams, 'anthropic_beta'> {
65+
extends BedrockChatCompletionsParams {
6666
anthropic_version?: string;
6767
user?: string;
6868
thinking?: {
@@ -477,9 +477,6 @@ export const BedrockChatCompleteResponseTransform: (
477477
}
478478

479479
if ('output' in response) {
480-
const cacheReadInputTokens = response.usage?.cacheReadInputTokens || 0;
481-
const cacheWriteInputTokens = response.usage?.cacheWriteInputTokens || 0;
482-
483480
let content: string = '';
484481
content = response.output.message.content
485482
.filter((item) => item.text)
@@ -489,6 +486,9 @@ export const BedrockChatCompleteResponseTransform: (
489486
? transformContentBlocks(response.output.message.content)
490487
: undefined;
491488

489+
const cacheReadInputTokens = response.usage?.cacheReadInputTokens || 0;
490+
const cacheWriteInputTokens = response.usage?.cacheWriteInputTokens || 0;
491+
492492
const responseObj: ChatCompletionResponse = {
493493
id: Date.now().toString(),
494494
object: 'chat.completion',
@@ -571,7 +571,6 @@ export const BedrockChatCompleteStreamChunkTransform: (
571571
streamState.currentToolCallIndex = -1;
572572
}
573573

574-
// final chunk
575574
if (parsedChunk.usage) {
576575
const cacheReadInputTokens = parsedChunk.usage?.cacheReadInputTokens || 0;
577576
const cacheWriteInputTokens = parsedChunk.usage?.cacheWriteInputTokens || 0;
@@ -605,9 +604,8 @@ export const BedrockChatCompleteStreamChunkTransform: (
605604
},
606605
// we only want to be sending this for anthropic models and this is not openai compliant
607606
...((cacheReadInputTokens > 0 || cacheWriteInputTokens > 0) && {
608-
cache_read_input_tokens: parsedChunk.usage.cacheReadInputTokens,
609-
cache_creation_input_tokens:
610-
parsedChunk.usage.cacheWriteInputTokens,
607+
cache_read_input_tokens: cacheReadInputTokens,
608+
cache_creation_input_tokens: cacheWriteInputTokens,
611609
}),
612610
},
613611
})}\n\n`,

src/providers/bedrock/constants.ts

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,15 @@
1+
export const BEDROCK_STABILITY_V1_MODELS = [
2+
'stable-diffusion-xl-v0',
3+
'stable-diffusion-xl-v1',
4+
];
5+
6+
export const bedrockInvokeModels = [
7+
'cohere.command-light-text-v14',
8+
'cohere.command-text-v14',
9+
'ai21.j2-mid-v1',
10+
'ai21.j2-ultra-v1',
11+
];
12+
113
export const LLAMA_2_SPECIAL_TOKENS = {
214
BEGINNING_OF_SENTENCE: '<s>',
315
END_OF_SENTENCE: '</s>',
@@ -34,15 +46,3 @@ export const MISTRAL_CONTROL_TOKENS = {
3446
MIDDLE: '[MIDDLE]',
3547
SUFFIX: '[SUFFIX]',
3648
};
37-
38-
export const BEDROCK_STABILITY_V1_MODELS = [
39-
'stable-diffusion-xl-v0',
40-
'stable-diffusion-xl-v1',
41-
];
42-
43-
export const bedrockInvokeModels = [
44-
'cohere.command-light-text-v14',
45-
'cohere.command-text-v14',
46-
'ai21.j2-mid-v1',
47-
'ai21.j2-ultra-v1',
48-
];

src/providers/bedrock/createFinetune.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ export const BedrockCreateFinetuneConfig: ProviderConfig = {
4646
return undefined;
4747
}
4848
return {
49-
s3Uri: decodeURIComponent(value.validation_file),
49+
s3Uri: decodeURIComponent(value.validation_file ?? ''),
5050
};
5151
},
5252
},

src/providers/bedrock/embed.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ export const BedrockCohereEmbedConfig: ProviderConfig = {
6060
},
6161
};
6262

63+
const g1EmbedModels = [
64+
'amazon.titan-embed-g1-text-02',
65+
'amazon.titan-embed-text-v1',
66+
'amazon.titan-embed-image-v1',
67+
];
68+
6369
export const BedrockTitanEmbedConfig: ProviderConfig = {
6470
input: [
6571
{
@@ -117,6 +123,8 @@ export const BedrockTitanEmbedConfig: ProviderConfig = {
117123
param: 'embeddingTypes',
118124
required: false,
119125
transform: (params: any): string[] | undefined => {
126+
const model = params.foundationModel || params.model || '';
127+
if (g1EmbedModels.includes(model)) return undefined;
120128
if (Array.isArray(params.encoding_format)) return params.encoding_format;
121129
else if (typeof params.encoding_format === 'string')
122130
return [params.encoding_format];

src/providers/bedrock/getBatchOutput.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import { BedrockGetBatchResponse } from './types';
55
import { getOctetStreamToOctetStreamTransformer } from '../../handlers/streamHandlerUtils';
66
import { BedrockUploadFileResponseTransforms } from './uploadFileUtils';
77
import { BEDROCK } from '../../globals';
8+
import { getAwsEndpointDomain } from './utils';
89

910
const getModelProvider = (modelId: string) => {
1011
let provider = '';
@@ -89,7 +90,7 @@ export const BedrockGetBatchOutputRequestHandler = async ({
8990
const awsS3ObjectKey = `${primaryKey}${jobId}/${inputS3URIParts[inputS3URIParts.length - 1]}.out`;
9091
const awsModelProvider = batchDetails.modelId;
9192

92-
const s3FileURL = `https://${awsS3Bucket}.s3.${awsRegion}.amazonaws.com/${awsS3ObjectKey}`;
93+
const s3FileURL = `https://${awsS3Bucket}.s3.${awsRegion}.${getAwsEndpointDomain(c)}/${awsS3ObjectKey}`;
9394
const s3FileHeaders = await BedrockAPIConfig.headers({
9495
c,
9596
providerOptions,

src/providers/bedrock/listBatches.ts

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,8 @@ export const BedrockListBatchesResponseTransform = (
2828
output_file_id: encodeURIComponent(
2929
batch.outputDataConfig.s3OutputDataConfig.s3Uri
3030
),
31-
finalizing_at: batch.endTime
32-
? new Date(batch.endTime).getTime()
33-
: undefined,
34-
expires_at: batch.jobExpirationTime
35-
? new Date(batch.jobExpirationTime).getTime()
36-
: undefined,
31+
finalizing_at: new Date(batch.endTime).getTime(),
32+
expires_at: new Date(batch.jobExpirationTime).getTime(),
3733
}));
3834

3935
return {

src/providers/bedrock/listFinetunes.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ export const BedrockListFinetuneResponseTransform: (
1010
if (responseStatus !== 200) {
1111
return BedrockErrorResponseTransform(response) || response;
1212
}
13+
1314
const records =
1415
response?.modelCustomizationJobSummaries as BedrockFinetuneRecord[];
1516
const openaiRecords = records.map(bedrockFinetuneToOpenAI);

0 commit comments

Comments
 (0)