Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/handlers/handlerUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,6 @@ export function constructConfigFromRequestHeaders(
azureApiVersion: requestHeaders[`x-${POWERED_BY}-azure-api-version`],
azureEndpointName: requestHeaders[`x-${POWERED_BY}-azure-endpoint-name`],
azureFoundryUrl: requestHeaders[`x-${POWERED_BY}-azure-foundry-url`],
azureExtraParams: requestHeaders[`x-${POWERED_BY}-azure-extra-params`],
azureAdToken: requestHeaders[`x-${POWERED_BY}-azure-ad-token`],
azureAuthMode: requestHeaders[`x-${POWERED_BY}-azure-auth-mode`],
azureManagedClientId:
Expand All @@ -867,6 +866,7 @@ export function constructConfigFromRequestHeaders(
requestHeaders[`x-${POWERED_BY}-azure-entra-client-secret`],
azureEntraTenantId: requestHeaders[`x-${POWERED_BY}-azure-entra-tenant-id`],
azureEntraScope: requestHeaders[`x-${POWERED_BY}-azure-entra-scope`],
azureExtraParameters: requestHeaders[`x-${POWERED_BY}-azure-extra-params`],
};

const awsConfig = {
Expand Down
4 changes: 2 additions & 2 deletions src/providers/azure-ai-inference/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ const AzureAIInferenceAPI: ProviderAPIConfig = {
headers: async ({ providerOptions, fn }) => {
const {
apiKey,
azureExtraParams,
azureExtraParameters,
azureDeploymentName,
azureAdToken,
azureAuthMode,
} = providerOptions;

const headers: Record<string, string> = {
'extra-parameters': azureExtraParams ?? 'drop',
'extra-parameters': azureExtraParameters ?? 'drop',
...(azureDeploymentName && {
'azureml-model-deployment': azureDeploymentName,
}),
Expand Down
49 changes: 34 additions & 15 deletions src/providers/bedrock/api.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { Context } from 'hono';
import { Options } from '../../types/requestBody';
import { Options, Params } from '../../types/requestBody';
import { endpointStrings, ProviderAPIConfig } from '../types';
import { bedrockInvokeModels } from './constants';
import {
getAwsEndpointDomain,
generateAWSHeaders,
getAssumedRoleCredentials,
getFoundationModelFromInferenceProfile,
providerAssumedRoleCredentials,
} from './utils';
Expand All @@ -18,6 +18,7 @@ interface BedrockAPIConfigInterface extends Omit<ProviderAPIConfig, 'headers'> {
transformedRequestBody: Record<string, any> | string;
transformedRequestUrl: string;
gatewayRequestBody?: Params;
headers?: Record<string, string>;
}) => Promise<Record<string, any>> | Record<string, any>;
}

Expand Down Expand Up @@ -66,7 +67,14 @@ const ENDPOINTS_TO_ROUTE_TO_S3 = [
'initiateMultipartUpload',
];

const getMethod = (fn: endpointStrings, transformedRequestUrl: string) => {
const getMethod = (
fn: endpointStrings,
transformedRequestUrl: string,
c: Context
) => {
if (fn === 'proxy') {
return c.req.method;
}
if (fn === 'uploadFile') {
const url = new URL(transformedRequestUrl);
return url.searchParams.get('partNumber') ? 'PUT' : 'POST';
Expand Down Expand Up @@ -121,36 +129,47 @@ const BedrockAPIConfig: BedrockAPIConfigInterface = {
gatewayRequestURL.split('/v1/files/')[1]
);
const bucketName = s3URL.replace('s3://', '').split('/')[0];
return `https://${bucketName}.s3.${providerOptions.awsRegion || 'us-east-1'}.amazonaws.com`;
return `https://${bucketName}.s3.${providerOptions.awsRegion || 'us-east-1'}.${getAwsEndpointDomain(c)}`;
}
if (fn === 'retrieveFileContent') {
const s3URL = decodeURIComponent(
gatewayRequestURL.split('/v1/files/')[1]
);
const bucketName = s3URL.replace('s3://', '').split('/')[0];
return `https://${bucketName}.s3.${providerOptions.awsRegion || 'us-east-1'}.amazonaws.com`;
return `https://${bucketName}.s3.${providerOptions.awsRegion || 'us-east-1'}.${getAwsEndpointDomain(c)}`;
}
if (fn === 'uploadFile')
return `https://${providerOptions.awsS3Bucket}.s3.${providerOptions.awsRegion || 'us-east-1'}.amazonaws.com`;
return `https://${providerOptions.awsS3Bucket}.s3.${providerOptions.awsRegion || 'us-east-1'}.${getAwsEndpointDomain(c)}`;
const isAWSControlPlaneEndpoint =
fn && AWS_CONTROL_PLANE_ENDPOINTS.includes(fn);
return `https://${isAWSControlPlaneEndpoint ? 'bedrock' : 'bedrock-runtime'}.${providerOptions.awsRegion || 'us-east-1'}.amazonaws.com`;
return `https://${isAWSControlPlaneEndpoint ? 'bedrock' : 'bedrock-runtime'}.${providerOptions.awsRegion || 'us-east-1'}.${getAwsEndpointDomain(c)}`;
},
headers: async ({
c,
fn,
providerOptions,
transformedRequestBody,
transformedRequestUrl,
gatewayRequestBody, // for proxy use the passed body blindly
headers: requestHeaders,
}) => {
const method = getMethod(fn as endpointStrings, transformedRequestUrl);
const service = getService(fn as endpointStrings);
const { awsService } = providerOptions;
const method =
c.get('method') || // method set specifically into context
getMethod(fn as endpointStrings, transformedRequestUrl, c); // method calculated
const service = awsService || getService(fn as endpointStrings);

const headers: Record<string, string> = {
'content-type': 'application/json',
};
let headers: Record<string, string> = {};

if (method === 'PUT' || method === 'GET') {
if (fn === 'proxy' && service !== 'bedrock') {
headers = { ...(requestHeaders ?? {}) };
} else {
headers = {
'content-type': 'application/json',
};
}

if ((method === 'PUT' || method === 'GET') && fn !== 'proxy') {
delete headers['content-type'];
}

Expand All @@ -160,7 +179,8 @@ const BedrockAPIConfig: BedrockAPIConfigInterface = {
await providerAssumedRoleCredentials(c, providerOptions);
}

let finalRequestBody = transformedRequestBody;
let finalRequestBody =
fn === 'proxy' ? gatewayRequestBody : transformedRequestBody;

if (['cancelFinetune', 'cancelBatch'].includes(fn as endpointStrings)) {
// Cancel doesn't require any body, but fetch is sending empty body, to match the signature this block is required.
Expand All @@ -183,7 +203,6 @@ const BedrockAPIConfig: BedrockAPIConfigInterface = {
fn,
gatewayRequestBodyJSON: gatewayRequestBody,
gatewayRequestURL,
c,
}) => {
if (fn === 'retrieveFile') {
const fileId = decodeURIComponent(
Expand Down
14 changes: 6 additions & 8 deletions src/providers/bedrock/chatComplete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ export interface BedrockChatCompletionsParams extends Params {
}

export interface BedrockConverseAnthropicChatCompletionsParams
extends Omit<BedrockChatCompletionsParams, 'anthropic_beta'> {
extends BedrockChatCompletionsParams {
anthropic_version?: string;
user?: string;
thinking?: {
Expand Down Expand Up @@ -511,9 +511,6 @@ export const BedrockChatCompleteResponseTransform: (
}

if ('output' in response) {
const cacheReadInputTokens = response.usage?.cacheReadInputTokens || 0;
const cacheWriteInputTokens = response.usage?.cacheWriteInputTokens || 0;

let content: string = '';
content = response.output.message.content
.filter((item) => item.text)
Expand All @@ -523,6 +520,9 @@ export const BedrockChatCompleteResponseTransform: (
? transformContentBlocks(response.output.message.content)
: undefined;

const cacheReadInputTokens = response.usage?.cacheReadInputTokens || 0;
const cacheWriteInputTokens = response.usage?.cacheWriteInputTokens || 0;

const responseObj: ChatCompletionResponse = {
id: Date.now().toString(),
object: 'chat.completion',
Expand Down Expand Up @@ -605,7 +605,6 @@ export const BedrockChatCompleteStreamChunkTransform: (
streamState.currentToolCallIndex = -1;
}

// final chunk
if (parsedChunk.usage) {
const cacheReadInputTokens = parsedChunk.usage?.cacheReadInputTokens || 0;
const cacheWriteInputTokens = parsedChunk.usage?.cacheWriteInputTokens || 0;
Expand Down Expand Up @@ -639,9 +638,8 @@ export const BedrockChatCompleteStreamChunkTransform: (
},
// we only want to be sending this for anthropic models and this is not openai compliant
...((cacheReadInputTokens > 0 || cacheWriteInputTokens > 0) && {
cache_read_input_tokens: parsedChunk.usage.cacheReadInputTokens,
cache_creation_input_tokens:
parsedChunk.usage.cacheWriteInputTokens,
cache_read_input_tokens: cacheReadInputTokens,
cache_creation_input_tokens: cacheWriteInputTokens,
}),
},
})}\n\n`,
Expand Down
24 changes: 12 additions & 12 deletions src/providers/bedrock/constants.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
export const BEDROCK_STABILITY_V1_MODELS = [
'stable-diffusion-xl-v0',
'stable-diffusion-xl-v1',
];

export const bedrockInvokeModels = [
'cohere.command-light-text-v14',
'cohere.command-text-v14',
'ai21.j2-mid-v1',
'ai21.j2-ultra-v1',
];

export const LLAMA_2_SPECIAL_TOKENS = {
BEGINNING_OF_SENTENCE: '<s>',
END_OF_SENTENCE: '</s>',
Expand Down Expand Up @@ -34,15 +46,3 @@ export const MISTRAL_CONTROL_TOKENS = {
MIDDLE: '[MIDDLE]',
SUFFIX: '[SUFFIX]',
};

export const BEDROCK_STABILITY_V1_MODELS = [
'stable-diffusion-xl-v0',
'stable-diffusion-xl-v1',
];

export const bedrockInvokeModels = [
'cohere.command-light-text-v14',
'cohere.command-text-v14',
'ai21.j2-mid-v1',
'ai21.j2-ultra-v1',
];
2 changes: 1 addition & 1 deletion src/providers/bedrock/createFinetune.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ export const BedrockCreateFinetuneConfig: ProviderConfig = {
return undefined;
}
return {
s3Uri: decodeURIComponent(value.validation_file),
s3Uri: decodeURIComponent(value.validation_file ?? ''),
};
},
},
Expand Down
8 changes: 8 additions & 0 deletions src/providers/bedrock/embed.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ export const BedrockCohereEmbedConfig: ProviderConfig = {
},
};

const g1EmbedModels = [
'amazon.titan-embed-g1-text-02',
'amazon.titan-embed-text-v1',
'amazon.titan-embed-image-v1',
];

export const BedrockTitanEmbedConfig: ProviderConfig = {
input: [
{
Expand Down Expand Up @@ -117,6 +123,8 @@ export const BedrockTitanEmbedConfig: ProviderConfig = {
param: 'embeddingTypes',
required: false,
transform: (params: any): string[] | undefined => {
const model = params.foundationModel || params.model || '';
if (g1EmbedModels.includes(model)) return undefined;
if (Array.isArray(params.encoding_format)) return params.encoding_format;
else if (typeof params.encoding_format === 'string')
return [params.encoding_format];
Expand Down
3 changes: 2 additions & 1 deletion src/providers/bedrock/getBatchOutput.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { BedrockGetBatchResponse } from './types';
import { getOctetStreamToOctetStreamTransformer } from '../../handlers/streamHandlerUtils';
import { BedrockUploadFileResponseTransforms } from './uploadFileUtils';
import { BEDROCK } from '../../globals';
import { getAwsEndpointDomain } from './utils';

const getModelProvider = (modelId: string) => {
let provider = '';
Expand Down Expand Up @@ -89,7 +90,7 @@ export const BedrockGetBatchOutputRequestHandler = async ({
const awsS3ObjectKey = `${primaryKey}${jobId}/${inputS3URIParts[inputS3URIParts.length - 1]}.out`;
const awsModelProvider = batchDetails.modelId;

const s3FileURL = `https://${awsS3Bucket}.s3.${awsRegion}.amazonaws.com/${awsS3ObjectKey}`;
const s3FileURL = `https://${awsS3Bucket}.s3.${awsRegion}.${getAwsEndpointDomain(c)}/${awsS3ObjectKey}`;
const s3FileHeaders = await BedrockAPIConfig.headers({
c,
providerOptions,
Expand Down
8 changes: 2 additions & 6 deletions src/providers/bedrock/listBatches.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,8 @@ export const BedrockListBatchesResponseTransform = (
output_file_id: encodeURIComponent(
batch.outputDataConfig.s3OutputDataConfig.s3Uri
),
finalizing_at: batch.endTime
? new Date(batch.endTime).getTime()
: undefined,
expires_at: batch.jobExpirationTime
? new Date(batch.jobExpirationTime).getTime()
: undefined,
finalizing_at: new Date(batch.endTime).getTime(),
expires_at: new Date(batch.jobExpirationTime).getTime(),
}));

return {
Expand Down
1 change: 1 addition & 0 deletions src/providers/bedrock/listFinetunes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export const BedrockListFinetuneResponseTransform: (
if (responseStatus !== 200) {
return BedrockErrorResponseTransform(response) || response;
}

const records =
response?.modelCustomizationJobSummaries as BedrockFinetuneRecord[];
const openaiRecords = records.map(bedrockFinetuneToOpenAI);
Expand Down
10 changes: 10 additions & 0 deletions src/providers/bedrock/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ export interface BedrockInferenceProfile {
type: string;
}

// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_ResponseSyntax
export enum BEDROCK_STOP_REASON {
end_turn = 'end_turn',
tool_use = 'tool_use',
max_tokens = 'max_tokens',
stop_sequence = 'stop_sequence',
guardrail_intervened = 'guardrail_intervened',
content_filtered = 'content_filtered',
}

export interface BedrockMessagesParams extends MessageCreateParamsBase {
additionalModelRequestFields?: Record<string, any>;
additional_model_request_fields?: Record<string, any>;
Expand Down
12 changes: 9 additions & 3 deletions src/providers/bedrock/uploadFileUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,10 @@ interface BedrockAnthropicChatCompleteResponse {
stop_reason: string;
model: string;
stop_sequence: null | string;
usage: {
input_tokens: number;
output_tokens: number;
};
}

export const BedrockAnthropicChatCompleteResponseTransform: (
Expand Down Expand Up @@ -874,9 +878,10 @@ export const BedrockAnthropicChatCompleteResponseTransform: (
},
],
usage: {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
prompt_tokens: response.usage.input_tokens,
completion_tokens: response.usage.output_tokens,
total_tokens:
response.usage.input_tokens + response.usage.output_tokens,
},
};
}
Expand Down Expand Up @@ -933,6 +938,7 @@ export const BedrockMistralChatCompleteResponseTransform: (
finish_reason: response.outputs[0].stop_reason,
},
],
// mistral not sending usage.
usage: {
prompt_tokens: 0,
completion_tokens: 0,
Expand Down
4 changes: 4 additions & 0 deletions src/providers/bedrock/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ import { GatewayError } from '../../errors/GatewayError';
import { BedrockFinetuneRecord, BedrockInferenceProfile } from './types';
import { FinetuneRequest } from '../types';
import { BEDROCK } from '../../globals';
import { Environment } from '../../utils/env';

export const getAwsEndpointDomain = (c: Context) =>
Environment(c).AWS_ENDPOINT_DOMAIN || 'amazonaws.com';

export const generateAWSHeaders = async (
body: Record<string, any> | string | undefined,
Expand Down
Loading