Merged
6 changes: 4 additions & 2 deletions src/handlers/responseHandlers.ts
@@ -102,7 +102,8 @@ export async function responseHandler(
         provider,
         responseTransformerFunction,
         requestURL,
-        strictOpenAiCompliance
+        strictOpenAiCompliance,
+        gatewayRequest
       ),
       responseJson: null,
     };
@@ -148,7 +149,8 @@ export async function responseHandler(
     response,
     responseTransformerFunction,
     strictOpenAiCompliance,
-    gatewayRequestUrl
+    gatewayRequestUrl,
+    gatewayRequest
   );

   return {
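For orientation, here is a minimal sketch of the two call paths after this change: responseHandler now forwards the parsed gateway request body to both the streaming and non-streaming handlers, so downstream transformers can read request fields such as model. The declarations below are stubbed and the Params type is simplified; they are not the gateway's real type definitions.

// Sketch with stubbed signatures; `Params` is simplified from the real
// src/types/requestBody type. The trailing `gatewayRequest` parameter is
// the addition this PR makes on both paths.
type Params = { model?: string; [key: string]: unknown };

declare function handleStreamingMode(
  response: Response,
  proxyProvider: string,
  responseTransformer: Function | undefined,
  requestURL: string,
  strictOpenAiCompliance: boolean,
  gatewayRequest: Params // new trailing argument
): Response;

declare function handleNonStreamingMode(
  response: Response,
  responseTransformer: Function | undefined,
  strictOpenAiCompliance: boolean,
  gatewayRequestUrl: string,
  gatewayRequest: Params // new trailing argument
): Promise<{ response: Response; json: Record<string, any> | null }>;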
38 changes: 27 additions & 11 deletions src/handlers/streamHandler.ts
@@ -11,6 +11,7 @@ import {
 import { VertexLlamaChatCompleteStreamChunkTransform } from '../providers/google-vertex-ai/chatComplete';
 import { OpenAIChatCompleteResponse } from '../providers/openai/chatComplete';
 import { OpenAICompleteResponse } from '../providers/openai/complete';
+import { Params } from '../types/requestBody';
 import { getStreamModeSplitPattern, type SplitPatternType } from '../utils';

 function readUInt32BE(buffer: Uint8Array, offset: number) {
@@ -49,7 +50,9 @@ function concatenateUint8Arrays(a: Uint8Array, b: Uint8Array): Uint8Array {
 export async function* readAWSStream(
   reader: ReadableStreamDefaultReader,
   transformFunction: Function | undefined,
-  fallbackChunkId: string
+  fallbackChunkId: string,
+  strictOpenAiCompliance: boolean,
+  gatewayRequest: Params
 ) {
   let buffer = new Uint8Array();
   let expectedLength = 0;
@@ -68,7 +71,9 @@ export async function* readAWSStream(
       const transformedChunk = transformFunction(
         payload,
         fallbackChunkId,
-        streamState
+        streamState,
+        strictOpenAiCompliance,
+        gatewayRequest
       );
       if (Array.isArray(transformedChunk)) {
         for (const item of transformedChunk) {
@@ -102,7 +107,9 @@ export async function* readAWSStream(
       const transformedChunk = transformFunction(
         payload,
         fallbackChunkId,
-        streamState
+        streamState,
+        strictOpenAiCompliance,
+        gatewayRequest
       );
       if (Array.isArray(transformedChunk)) {
         for (const item of transformedChunk) {
@@ -124,7 +131,8 @@ export async function* readStream(
   transformFunction: Function | undefined,
   isSleepTimeRequired: boolean,
   fallbackChunkId: string,
-  strictOpenAiCompliance: boolean
+  strictOpenAiCompliance: boolean,
+  gatewayRequest: Params
 ) {
   let buffer = '';
   const decoder = new TextDecoder();
@@ -140,7 +148,8 @@ export async function* readStream(
         buffer,
         fallbackChunkId,
         streamState,
-        strictOpenAiCompliance
+        strictOpenAiCompliance,
+        gatewayRequest
       );
     } else {
       yield buffer;
@@ -171,7 +180,8 @@ export async function* readStream(
         part,
         fallbackChunkId,
         streamState,
-        strictOpenAiCompliance
+        strictOpenAiCompliance,
+        gatewayRequest
       );
       if (transformedChunk !== undefined) {
         yield transformedChunk;
@@ -215,7 +225,8 @@ export async function handleNonStreamingMode(
   response: Response,
   responseTransformer: Function | undefined,
   strictOpenAiCompliance: boolean,
-  gatewayRequestUrl: string
+  gatewayRequestUrl: string,
+  gatewayRequest: Params
 ): Promise<{
   response: Response;
   json: Record<string, any>;
@@ -241,7 +252,8 @@ export async function handleNonStreamingMode(
       response.status,
       response.headers,
       strictOpenAiCompliance,
-      gatewayRequestUrl
+      gatewayRequestUrl,
+      gatewayRequest
     );
   }

@@ -270,7 +282,8 @@ export function handleStreamingMode(
   proxyProvider: string,
   responseTransformer: Function | undefined,
   requestURL: string,
-  strictOpenAiCompliance: boolean
+  strictOpenAiCompliance: boolean,
+  gatewayRequest: Params
 ): Response {
   const splitPattern = getStreamModeSplitPattern(proxyProvider, requestURL);
   // If the provider doesn't supply completion id,
@@ -291,7 +304,9 @@ export function handleStreamingMode(
       for await (const chunk of readAWSStream(
         reader,
         responseTransformer,
-        fallbackChunkId
+        fallbackChunkId,
+        strictOpenAiCompliance,
+        gatewayRequest
       )) {
         await writer.write(encoder.encode(chunk));
       }
@@ -305,7 +320,8 @@ export function handleStreamingMode(
         responseTransformer,
         isSleepTimeRequired,
         fallbackChunkId,
-        strictOpenAiCompliance
+        strictOpenAiCompliance,
+        gatewayRequest
       )) {
         await writer.write(encoder.encode(chunk));
       }
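The net effect in streamHandler.ts is that every chunk transformer is now invoked with two extra trailing arguments. A minimal, self-contained sketch of that calling convention follows; the types and the transformer are toys (Params stands in for the gateway's request-body type), not the real provider code.

// Toy model of the transformer calling convention after this PR.
type Params = { model?: string; [key: string]: unknown };

type StreamChunkTransformer = (
  responseChunk: string,
  fallbackChunkId: string,
  streamState: Record<string, any>,
  strictOpenAiCompliance: boolean,
  gatewayRequest: Params
) => string | string[] | undefined;

// A toy transformer that stamps the requested model onto each outgoing chunk,
// which is what the provider transforms below use gatewayRequest for.
const stampModel: StreamChunkTransformer = (
  responseChunk,
  fallbackChunkId,
  _streamState,
  _strictOpenAiCompliance,
  gatewayRequest
) => {
  const parsed = JSON.parse(responseChunk);
  return `data: ${JSON.stringify({
    ...parsed,
    id: parsed.id ?? fallbackChunkId,
    model: gatewayRequest.model || '',
  })}\n\n`;
};

// readStream/readAWSStream forward the new arguments like this:
console.log(
  stampModel('{"text":"hi"}', 'chunk-1', {}, false, { model: 'gpt-4o' })
);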
1 change: 0 additions & 1 deletion src/providers/ai21/embed.ts
@@ -39,7 +39,6 @@ export const AI21EmbedResponseTransform: (
     );
     if (errorResposne) return errorResposne;
   }
-
   if ('results' in response) {
     return {
       object: 'list',
10 changes: 6 additions & 4 deletions src/providers/anthropic/chatComplete.ts
@@ -410,6 +410,7 @@ export interface AnthropicChatCompleteStreamResponse {
       cache_creation_input_tokens?: number;
       cache_read_input_tokens?: number;
     };
+    model?: string;
   };
   error?: AnthropicErrorObject;
 }
@@ -529,7 +530,7 @@ export const AnthropicChatCompleteStreamChunkTransform: (
   response: string,
   fallbackId: string,
   streamState: AnthropicStreamState,
-  strictOpenAiCompliance: boolean
+  _strictOpenAiCompliance: boolean
 ) => string | undefined = (
   responseChunk,
   fallbackId,
@@ -585,6 +586,7 @@ export const AnthropicChatCompleteStreamChunkTransform: (
     parsedChunk.message?.usage?.cache_creation_input_tokens;

   if (parsedChunk.type === 'message_start' && parsedChunk.message?.usage) {
+    streamState.model = parsedChunk?.message?.model ?? '';
     streamState.usage = {
       prompt_tokens: parsedChunk.message?.usage?.input_tokens,
       ...(shouldSendCacheUsage && {
@@ -599,7 +601,7 @@ export const AnthropicChatCompleteStreamChunkTransform: (
       id: fallbackId,
       object: 'chat.completion.chunk',
       created: Math.floor(Date.now() / 1000),
-      model: '',
+      model: streamState.model,
       provider: ANTHROPIC,
       choices: [
         {
@@ -626,7 +628,7 @@ export const AnthropicChatCompleteStreamChunkTransform: (
       id: fallbackId,
       object: 'chat.completion.chunk',
       created: Math.floor(Date.now() / 1000),
-      model: '',
+      model: streamState.model,
       provider: ANTHROPIC,
       choices: [
         {
@@ -689,7 +691,7 @@ export const AnthropicChatCompleteStreamChunkTransform: (
       id: fallbackId,
       object: 'chat.completion.chunk',
       created: Math.floor(Date.now() / 1000),
-      model: '',
+      model: streamState.model,
       provider: ANTHROPIC,
       choices: [
         {
1 change: 1 addition & 0 deletions src/providers/anthropic/types.ts
@@ -6,4 +6,5 @@ export type AnthropicStreamState = {
     cache_read_input_tokens?: number;
     cache_creation_input_tokens?: number;
   };
+  model?: string;
 };
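Why the new model field on AnthropicStreamState: Anthropic only reports the model name on the message_start event, so the transform above captures it into per-stream state and echoes it on every later chunk instead of the previous hardcoded ''. A simplified, self-contained sketch of that pattern (toy types, not the gateway's real ones):

// Simplified sketch of the Anthropic stream-state pattern.
type AnthropicStreamStateSketch = { model?: string };

interface AnthropicEvent {
  type: string;
  message?: { model?: string };
}

function toOpenAIChunk(
  parsedChunk: AnthropicEvent,
  fallbackId: string,
  streamState: AnthropicStreamStateSketch
): string {
  // message_start is the only event carrying the model name; remember it.
  if (parsedChunk.type === 'message_start') {
    streamState.model = parsedChunk.message?.model ?? '';
  }
  return `data: ${JSON.stringify({
    id: fallbackId,
    object: 'chat.completion.chunk',
    created: Math.floor(Date.now() / 1000),
    model: streamState.model, // previously hardcoded to ''
  })}\n\n`;
}

const state: AnthropicStreamStateSketch = {};
toOpenAIChunk(
  { type: 'message_start', message: { model: 'claude-3-5-sonnet-20240620' } },
  'msg_1',
  state
);
// Later events reuse the captured model name:
console.log(toOpenAIChunk({ type: 'content_block_delta' }, 'msg_1', state));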
66 changes: 48 additions & 18 deletions src/providers/bedrock/chatComplete.ts
@@ -402,12 +402,16 @@ export const BedrockChatCompleteResponseTransform: (
   response: BedrockChatCompletionResponse | BedrockErrorResponse,
   responseStatus: number,
   responseHeaders: Headers,
-  strictOpenAiCompliance: boolean
+  strictOpenAiCompliance: boolean,
+  _gatewayRequestUrl: string,
+  gatewayRequest: Params
 ) => ChatCompletionResponse | ErrorResponse = (
   response,
   responseStatus,
-  _responseHeaders,
-  strictOpenAiCompliance
+  responseHeaders,
+  strictOpenAiCompliance,
+  _gatewayRequestUrl,
+  gatewayRequest
 ) => {
   if (responseStatus !== 200) {
     const errorResponse = BedrockErrorResponseTransform(
@@ -430,7 +434,7 @@ export const BedrockChatCompleteResponseTransform: (
       id: Date.now().toString(),
       object: 'chat.completion',
       created: Math.floor(Date.now() / 1000),
-      model: '',
+      model: gatewayRequest.model || '',
       provider: BEDROCK,
       choices: [
         {
@@ -512,12 +516,14 @@ export const BedrockChatCompleteStreamChunkTransform: (
   response: string,
   fallbackId: string,
   streamState: BedrockStreamState,
-  strictOpenAiCompliance: boolean
+  strictOpenAiCompliance: boolean,
+  gatewayRequest: Params
 ) => string | string[] = (
   responseChunk,
   fallbackId,
   streamState,
-  strictOpenAiCompliance
+  strictOpenAiCompliance,
+  gatewayRequest
 ) => {
   const parsedChunk: BedrockChatCompleteStreamChunk = JSON.parse(responseChunk);
   if (parsedChunk.stopReason) {
@@ -533,7 +539,7 @@ export const BedrockChatCompleteStreamChunkTransform: (
       id: fallbackId,
       object: 'chat.completion.chunk',
       created: Math.floor(Date.now() / 1000),
-      model: '',
+      model: gatewayRequest.model || '',
       provider: BEDROCK,
       choices: [
         {
@@ -597,7 +603,7 @@ export const BedrockChatCompleteStreamChunkTransform: (
       id: fallbackId,
       object: 'chat.completion.chunk',
       created: Math.floor(Date.now() / 1000),
-      model: '',
+      model: gatewayRequest.model || '',
       provider: BEDROCK,
       choices: [
         {
@@ -802,11 +808,17 @@ export const BedrockCohereChatCompleteConfig: ProviderConfig = {
 export const BedrockCohereChatCompleteResponseTransform: (
   response: BedrockCohereCompleteResponse | BedrockErrorResponse,
   responseStatus: number,
-  responseHeaders: Headers
+  responseHeaders: Headers,
+  strictOpenAiCompliance: boolean,
+  gatewayRequestUrl: string,
+  gatewayRequest: Params
 ) => ChatCompletionResponse | ErrorResponse = (
   response,
   responseStatus,
-  responseHeaders
+  responseHeaders,
+  _strictOpenAiCompliance,
+  _gatewayRequestUrl,
+  gatewayRequest
 ) => {
   if (responseStatus !== 200) {
     const errorResposne = BedrockErrorResponseTransform(
@@ -815,6 +827,8 @@ export const BedrockCohereChatCompleteResponseTransform: (
     if (errorResposne) return errorResposne;
   }

+  const model = gatewayRequest.model || '';
+
   if ('generations' in response) {
     const prompt_tokens =
       Number(responseHeaders.get('X-Amzn-Bedrock-Input-Token-Count')) || 0;
@@ -824,7 +838,7 @@ export const BedrockCohereChatCompleteResponseTransform: (
       id: Date.now().toString(),
       object: 'chat.completion',
       created: Math.floor(Date.now() / 1000),
-      model: '',
+      model,
       provider: BEDROCK,
       choices: response.generations.map((generation, index) => ({
         index: index,
Expand All @@ -847,21 +861,31 @@ export const BedrockCohereChatCompleteResponseTransform: (

export const BedrockCohereChatCompleteStreamChunkTransform: (
response: string,
fallbackId: string
) => string | string[] = (responseChunk, fallbackId) => {
fallbackId: string,
_streamState: Record<string, any>,
_strictOpenAiCompliance: boolean,
gatewayRequest: Params
) => string | string[] = (
responseChunk,
fallbackId,
_streamState,
_strictOpenAiCompliance,
gatewayRequest
) => {
let chunk = responseChunk.trim();
chunk = chunk.replace(/^data: /, '');
chunk = chunk.trim();
const parsedChunk: BedrockCohereStreamChunk = JSON.parse(chunk);

const model = gatewayRequest.model || '';
// discard the last cohere chunk as it sends the whole response combined.
if (parsedChunk.is_finished) {
return [
`data: ${JSON.stringify({
id: fallbackId,
object: 'chat.completion.chunk',
created: Math.floor(Date.now() / 1000),
model: '',
model,
provider: BEDROCK,
choices: [
{
@@ -888,7 +912,7 @@ export const BedrockCohereChatCompleteStreamChunkTransform: (
       id: fallbackId,
       object: 'chat.completion.chunk',
       created: Math.floor(Date.now() / 1000),
-      model: '',
+      model,
       provider: BEDROCK,
       choices: [
         {
@@ -978,11 +1002,17 @@ export const BedrockAI21ChatCompleteConfig: ProviderConfig = {
 export const BedrockAI21ChatCompleteResponseTransform: (
   response: BedrockAI21CompleteResponse | BedrockErrorResponse,
   responseStatus: number,
-  responseHeaders: Headers
+  responseHeaders: Headers,
+  strictOpenAiCompliance: boolean,
+  _gatewayRequestUrl: string,
+  gatewayRequest: Params
 ) => ChatCompletionResponse | ErrorResponse = (
   response,
   responseStatus,
-  responseHeaders
+  responseHeaders,
+  _strictOpenAiCompliance,
+  _gatewayRequestUrl,
+  gatewayRequest
 ) => {
   if (responseStatus !== 200) {
     const errorResposne = BedrockErrorResponseTransform(
@@ -1000,7 +1030,7 @@ export const BedrockAI21ChatCompleteResponseTransform: (
       id: response.id.toString(),
       object: 'chat.completion',
       created: Math.floor(Date.now() / 1000),
-      model: '',
+      model: gatewayRequest.model ?? '',
       provider: BEDROCK,
       choices: response.completions.map((completion, index) => ({
         index: index,
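Bedrock responses do not echo a model name, so all four transforms above fall back to the model named in the incoming gateway request. A minimal sketch of that mapping, with toy types; the real transforms also handle errors, usage counts, and streaming:

// Toy version of the Bedrock response mapping (simplified types).
type BedrockParams = { model?: string };

function toOpenAIResponse(
  generations: { text: string }[],
  gatewayRequest: BedrockParams
) {
  return {
    id: Date.now().toString(),
    object: 'chat.completion',
    created: Math.floor(Date.now() / 1000),
    // Bedrock payloads carry no model, so use the requested one (was '').
    model: gatewayRequest.model || '',
    provider: 'bedrock',
    choices: generations.map((generation, index) => ({
      index,
      message: { role: 'assistant', content: generation.text },
      finish_reason: 'stop',
    })),
  };
}

console.log(
  toOpenAIResponse([{ text: 'Hello!' }], { model: 'cohere.command-text-v14' })
);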