Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/handlers/handlerUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,6 @@ export function constructConfigFromRequestHeaders(
azureApiVersion: requestHeaders[`x-${POWERED_BY}-azure-api-version`],
azureEndpointName: requestHeaders[`x-${POWERED_BY}-azure-endpoint-name`],
azureFoundryUrl: requestHeaders[`x-${POWERED_BY}-azure-foundry-url`],
azureExtraParams: requestHeaders[`x-${POWERED_BY}-azure-extra-params`],
azureAdToken: requestHeaders[`x-${POWERED_BY}-azure-ad-token`],
azureAuthMode: requestHeaders[`x-${POWERED_BY}-azure-auth-mode`],
azureManagedClientId:
Expand All @@ -867,6 +866,7 @@ export function constructConfigFromRequestHeaders(
requestHeaders[`x-${POWERED_BY}-azure-entra-client-secret`],
azureEntraTenantId: requestHeaders[`x-${POWERED_BY}-azure-entra-tenant-id`],
azureEntraScope: requestHeaders[`x-${POWERED_BY}-azure-entra-scope`],
azureExtraParameters: requestHeaders[`x-${POWERED_BY}-azure-extra-params`],
};

const awsConfig = {
Expand Down
4 changes: 2 additions & 2 deletions src/providers/azure-ai-inference/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ const AzureAIInferenceAPI: ProviderAPIConfig = {
headers: async ({ providerOptions, fn }) => {
const {
apiKey,
azureExtraParams,
azureExtraParameters,
azureDeploymentName,
azureAdToken,
azureAuthMode,
} = providerOptions;

const headers: Record<string, string> = {
'extra-parameters': azureExtraParams ?? 'drop',
'extra-parameters': azureExtraParameters ?? 'drop',
...(azureDeploymentName && {
'azureml-model-deployment': azureDeploymentName,
}),
Expand Down
49 changes: 34 additions & 15 deletions src/providers/bedrock/api.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { Context } from 'hono';
import { Options } from '../../types/requestBody';
import { Options, Params } from '../../types/requestBody';
import { endpointStrings, ProviderAPIConfig } from '../types';
import { bedrockInvokeModels } from './constants';
import {
getAwsEndpointDomain,
generateAWSHeaders,
getAssumedRoleCredentials,
getFoundationModelFromInferenceProfile,
providerAssumedRoleCredentials,
} from './utils';
Expand All @@ -18,6 +18,7 @@ interface BedrockAPIConfigInterface extends Omit<ProviderAPIConfig, 'headers'> {
transformedRequestBody: Record<string, any> | string;
transformedRequestUrl: string;
gatewayRequestBody?: Params;
headers?: Record<string, string>;
}) => Promise<Record<string, any>> | Record<string, any>;
}

Expand Down Expand Up @@ -66,7 +67,14 @@ const ENDPOINTS_TO_ROUTE_TO_S3 = [
'initiateMultipartUpload',
];

const getMethod = (fn: endpointStrings, transformedRequestUrl: string) => {
const getMethod = (
fn: endpointStrings,
transformedRequestUrl: string,
c: Context
) => {
if (fn === 'proxy') {
return c.req.method;
}
if (fn === 'uploadFile') {
const url = new URL(transformedRequestUrl);
return url.searchParams.get('partNumber') ? 'PUT' : 'POST';
Expand Down Expand Up @@ -121,36 +129,47 @@ const BedrockAPIConfig: BedrockAPIConfigInterface = {
gatewayRequestURL.split('/v1/files/')[1]
);
const bucketName = s3URL.replace('s3://', '').split('/')[0];
return `https://${bucketName}.s3.${providerOptions.awsRegion || 'us-east-1'}.amazonaws.com`;
return `https://${bucketName}.s3.${providerOptions.awsRegion || 'us-east-1'}.${getAwsEndpointDomain(c)}`;
}
if (fn === 'retrieveFileContent') {
const s3URL = decodeURIComponent(
gatewayRequestURL.split('/v1/files/')[1]
);
const bucketName = s3URL.replace('s3://', '').split('/')[0];
return `https://${bucketName}.s3.${providerOptions.awsRegion || 'us-east-1'}.amazonaws.com`;
return `https://${bucketName}.s3.${providerOptions.awsRegion || 'us-east-1'}.${getAwsEndpointDomain(c)}`;
}
if (fn === 'uploadFile')
return `https://${providerOptions.awsS3Bucket}.s3.${providerOptions.awsRegion || 'us-east-1'}.amazonaws.com`;
return `https://${providerOptions.awsS3Bucket}.s3.${providerOptions.awsRegion || 'us-east-1'}.${getAwsEndpointDomain(c)}`;
const isAWSControlPlaneEndpoint =
fn && AWS_CONTROL_PLANE_ENDPOINTS.includes(fn);
return `https://${isAWSControlPlaneEndpoint ? 'bedrock' : 'bedrock-runtime'}.${providerOptions.awsRegion || 'us-east-1'}.amazonaws.com`;
return `https://${isAWSControlPlaneEndpoint ? 'bedrock' : 'bedrock-runtime'}.${providerOptions.awsRegion || 'us-east-1'}.${getAwsEndpointDomain(c)}`;
},
headers: async ({
c,
fn,
providerOptions,
transformedRequestBody,
transformedRequestUrl,
gatewayRequestBody, // for proxy use the passed body blindly
headers: requestHeaders,
}) => {
const method = getMethod(fn as endpointStrings, transformedRequestUrl);
const service = getService(fn as endpointStrings);
const { awsService } = providerOptions;
const method =
c.get('method') || // method set specifically into context
getMethod(fn as endpointStrings, transformedRequestUrl, c); // method calculated
const service = awsService || getService(fn as endpointStrings);

const headers: Record<string, string> = {
'content-type': 'application/json',
};
let headers: Record<string, string> = {};

if (method === 'PUT' || method === 'GET') {
if (fn === 'proxy' && service !== 'bedrock') {
headers = { ...(requestHeaders ?? {}) };
} else {
headers = {
'content-type': 'application/json',
};
}

if ((method === 'PUT' || method === 'GET') && fn !== 'proxy') {
delete headers['content-type'];
}

Expand All @@ -160,7 +179,8 @@ const BedrockAPIConfig: BedrockAPIConfigInterface = {
await providerAssumedRoleCredentials(c, providerOptions);
}

let finalRequestBody = transformedRequestBody;
let finalRequestBody =
fn === 'proxy' ? gatewayRequestBody : transformedRequestBody;

if (['cancelFinetune', 'cancelBatch'].includes(fn as endpointStrings)) {
// Cancel doesn't require any body, but fetch is sending empty body, to match the signature this block is required.
Expand All @@ -183,7 +203,6 @@ const BedrockAPIConfig: BedrockAPIConfigInterface = {
fn,
gatewayRequestBodyJSON: gatewayRequestBody,
gatewayRequestURL,
c,
}) => {
if (fn === 'retrieveFile') {
const fileId = decodeURIComponent(
Expand Down
14 changes: 6 additions & 8 deletions src/providers/bedrock/chatComplete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ export interface BedrockChatCompletionsParams extends Params {
}

export interface BedrockConverseAnthropicChatCompletionsParams
extends Omit<BedrockChatCompletionsParams, 'anthropic_beta'> {
extends BedrockChatCompletionsParams {
anthropic_version?: string;
user?: string;
thinking?: {
Expand Down Expand Up @@ -511,9 +511,6 @@ export const BedrockChatCompleteResponseTransform: (
}

if ('output' in response) {
const cacheReadInputTokens = response.usage?.cacheReadInputTokens || 0;
const cacheWriteInputTokens = response.usage?.cacheWriteInputTokens || 0;

let content: string = '';
content = response.output.message.content
.filter((item) => item.text)
Expand All @@ -523,6 +520,9 @@ export const BedrockChatCompleteResponseTransform: (
? transformContentBlocks(response.output.message.content)
: undefined;

const cacheReadInputTokens = response.usage?.cacheReadInputTokens || 0;
const cacheWriteInputTokens = response.usage?.cacheWriteInputTokens || 0;

const responseObj: ChatCompletionResponse = {
id: Date.now().toString(),
object: 'chat.completion',
Expand Down Expand Up @@ -605,7 +605,6 @@ export const BedrockChatCompleteStreamChunkTransform: (
streamState.currentToolCallIndex = -1;
}

// final chunk
if (parsedChunk.usage) {
const cacheReadInputTokens = parsedChunk.usage?.cacheReadInputTokens || 0;
const cacheWriteInputTokens = parsedChunk.usage?.cacheWriteInputTokens || 0;
Expand Down Expand Up @@ -639,9 +638,8 @@ export const BedrockChatCompleteStreamChunkTransform: (
},
// we only want to be sending this for anthropic models and this is not openai compliant
...((cacheReadInputTokens > 0 || cacheWriteInputTokens > 0) && {
cache_read_input_tokens: parsedChunk.usage.cacheReadInputTokens,
cache_creation_input_tokens:
parsedChunk.usage.cacheWriteInputTokens,
cache_read_input_tokens: cacheReadInputTokens,
cache_creation_input_tokens: cacheWriteInputTokens,
}),
},
})}\n\n`,
Expand Down
24 changes: 12 additions & 12 deletions src/providers/bedrock/constants.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
export const BEDROCK_STABILITY_V1_MODELS = [
'stable-diffusion-xl-v0',
'stable-diffusion-xl-v1',
];

export const bedrockInvokeModels = [
'cohere.command-light-text-v14',
'cohere.command-text-v14',
'ai21.j2-mid-v1',
'ai21.j2-ultra-v1',
];

export const LLAMA_2_SPECIAL_TOKENS = {
BEGINNING_OF_SENTENCE: '<s>',
END_OF_SENTENCE: '</s>',
Expand Down Expand Up @@ -34,15 +46,3 @@ export const MISTRAL_CONTROL_TOKENS = {
MIDDLE: '[MIDDLE]',
SUFFIX: '[SUFFIX]',
};

export const BEDROCK_STABILITY_V1_MODELS = [
'stable-diffusion-xl-v0',
'stable-diffusion-xl-v1',
];

export const bedrockInvokeModels = [
'cohere.command-light-text-v14',
'cohere.command-text-v14',
'ai21.j2-mid-v1',
'ai21.j2-ultra-v1',
];
2 changes: 1 addition & 1 deletion src/providers/bedrock/createFinetune.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ export const BedrockCreateFinetuneConfig: ProviderConfig = {
return undefined;
}
return {
s3Uri: decodeURIComponent(value.validation_file),
s3Uri: decodeURIComponent(value.validation_file ?? ''),
};
},
},
Expand Down
8 changes: 8 additions & 0 deletions src/providers/bedrock/embed.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ export const BedrockCohereEmbedConfig: ProviderConfig = {
},
};

const g1EmbedModels = [
'amazon.titan-embed-g1-text-02',
'amazon.titan-embed-text-v1',
'amazon.titan-embed-image-v1',
];

export const BedrockTitanEmbedConfig: ProviderConfig = {
input: [
{
Expand Down Expand Up @@ -117,6 +123,8 @@ export const BedrockTitanEmbedConfig: ProviderConfig = {
param: 'embeddingTypes',
required: false,
transform: (params: any): string[] | undefined => {
const model = params.foundationModel || params.model || '';
if (g1EmbedModels.includes(model)) return undefined;
if (Array.isArray(params.encoding_format)) return params.encoding_format;
else if (typeof params.encoding_format === 'string')
return [params.encoding_format];
Expand Down
3 changes: 2 additions & 1 deletion src/providers/bedrock/getBatchOutput.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { BedrockGetBatchResponse } from './types';
import { getOctetStreamToOctetStreamTransformer } from '../../handlers/streamHandlerUtils';
import { BedrockUploadFileResponseTransforms } from './uploadFileUtils';
import { BEDROCK } from '../../globals';
import { getAwsEndpointDomain } from './utils';

const getModelProvider = (modelId: string) => {
let provider = '';
Expand Down Expand Up @@ -89,7 +90,7 @@ export const BedrockGetBatchOutputRequestHandler = async ({
const awsS3ObjectKey = `${primaryKey}${jobId}/${inputS3URIParts[inputS3URIParts.length - 1]}.out`;
const awsModelProvider = batchDetails.modelId;

const s3FileURL = `https://${awsS3Bucket}.s3.${awsRegion}.amazonaws.com/${awsS3ObjectKey}`;
const s3FileURL = `https://${awsS3Bucket}.s3.${awsRegion}.${getAwsEndpointDomain(c)}/${awsS3ObjectKey}`;
const s3FileHeaders = await BedrockAPIConfig.headers({
c,
providerOptions,
Expand Down
8 changes: 2 additions & 6 deletions src/providers/bedrock/listBatches.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,8 @@ export const BedrockListBatchesResponseTransform = (
output_file_id: encodeURIComponent(
batch.outputDataConfig.s3OutputDataConfig.s3Uri
),
finalizing_at: batch.endTime
? new Date(batch.endTime).getTime()
: undefined,
expires_at: batch.jobExpirationTime
? new Date(batch.jobExpirationTime).getTime()
: undefined,
finalizing_at: new Date(batch.endTime).getTime(),
expires_at: new Date(batch.jobExpirationTime).getTime(),
}));

return {
Expand Down
1 change: 1 addition & 0 deletions src/providers/bedrock/listFinetunes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export const BedrockListFinetuneResponseTransform: (
if (responseStatus !== 200) {
return BedrockErrorResponseTransform(response) || response;
}

const records =
response?.modelCustomizationJobSummaries as BedrockFinetuneRecord[];
const openaiRecords = records.map(bedrockFinetuneToOpenAI);
Expand Down
10 changes: 10 additions & 0 deletions src/providers/bedrock/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ export interface BedrockInferenceProfile {
type: string;
}

// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_ResponseSyntax
export enum BEDROCK_STOP_REASON {
end_turn = 'end_turn',
tool_use = 'tool_use',
max_tokens = 'max_tokens',
stop_sequence = 'stop_sequence',
guardrail_intervened = 'guardrail_intervened',
content_filtered = 'content_filtered',
}

export interface BedrockMessagesParams extends MessageCreateParamsBase {
additionalModelRequestFields?: Record<string, any>;
additional_model_request_fields?: Record<string, any>;
Expand Down
12 changes: 9 additions & 3 deletions src/providers/bedrock/uploadFileUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,10 @@ interface BedrockAnthropicChatCompleteResponse {
stop_reason: string;
model: string;
stop_sequence: null | string;
usage: {
input_tokens: number;
output_tokens: number;
};
}

export const BedrockAnthropicChatCompleteResponseTransform: (
Expand Down Expand Up @@ -874,9 +878,10 @@ export const BedrockAnthropicChatCompleteResponseTransform: (
},
],
usage: {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
prompt_tokens: response.usage.input_tokens,
completion_tokens: response.usage.output_tokens,
total_tokens:
response.usage.input_tokens + response.usage.output_tokens,
},
};
}
Expand Down Expand Up @@ -933,6 +938,7 @@ export const BedrockMistralChatCompleteResponseTransform: (
finish_reason: response.outputs[0].stop_reason,
},
],
// mistral not sending usage.
usage: {
prompt_tokens: 0,
completion_tokens: 0,
Expand Down
4 changes: 4 additions & 0 deletions src/providers/bedrock/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ import { GatewayError } from '../../errors/GatewayError';
import { BedrockFinetuneRecord, BedrockInferenceProfile } from './types';
import { FinetuneRequest } from '../types';
import { BEDROCK } from '../../globals';
import { Environment } from '../../utils/env';

export const getAwsEndpointDomain = (c: Context) =>
Environment(c).AWS_ENDPOINT_DOMAIN || 'amazonaws.com';

export const generateAWSHeaders = async (
body: Record<string, any> | string | undefined,
Expand Down
Loading