2 changes: 2 additions & 0 deletions src/errors/GatewayError.ts
@@ -1,9 +1,11 @@
export class GatewayError extends Error {
constructor(
message: string,
public status: number = 500,
public cause?: Error
) {
super(message);
this.name = 'GatewayError';
this.status = status;
}
Comment on lines 2 to 10
🐛 Bug Fix

Issue: Redundant assignment of the status property in the constructor
Fix: Remove the assignment; the public status parameter property already initializes it
Impact: Cleaner code, no functional change

Suggested change
constructor(
message: string,
public status: number = 500,
public cause?: Error
) {
super(message);
this.name = 'GatewayError';
this.status = status;
}
constructor(
message: string,
public status: number = 500,
public cause?: Error
) {
super(message);
this.name = 'GatewayError';
}
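
A minimal sketch of why the assignment is redundant (the class name below is illustrative, not from the PR): a TypeScript constructor parameter declared with an access modifier is a parameter property, so the compiler both declares the field and assigns the argument to it.

```ts
// Hypothetical class, for illustration only.
class ExampleError extends Error {
  constructor(
    message: string,
    // `public status` declares the field and assigns the argument automatically,
    // so a manual `this.status = status;` in the body is redundant.
    public status: number = 500
  ) {
    super(message);
    this.name = 'ExampleError';
  }
}

console.log(new ExampleError('upstream failed').status); // 500
```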

}
2 changes: 1 addition & 1 deletion src/handlers/handlerUtils.ts
@@ -809,7 +809,7 @@ export async function tryTargetsRecursively(
message: errorMessage,
}),
{
status: 500,
status: error instanceof GatewayError ? error.status : 500,
🛠️ Code Refactor

Issue: Hardcoded error status code
Fix: Use status from GatewayError if available, otherwise default to 500
Impact: More accurate error status propagation

Suggested change
status: error instanceof GatewayError ? error.status : 500,
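
For context, a hedged sketch of how this plays out on the catch side (simplified; `toErrorResponse` is an illustrative stand-in, not the actual helper in handlerUtils.ts):

```ts
import { GatewayError } from '../errors/GatewayError';

// Assumes a fetch-style Response, as used by the gateway handlers.
function toErrorResponse(error: unknown, errorMessage: string): Response {
  return new Response(JSON.stringify({ message: errorMessage }), {
    // Propagate the status carried by a GatewayError (e.g. 400); otherwise fall back to 500.
    status: error instanceof GatewayError ? error.status : 500,
    headers: { 'content-type': 'application/json' },
  });
}
```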

headers: {
'content-type': 'application/json',
// Add this header so that the fallback loop can be interrupted if it's an exception.
25 changes: 18 additions & 7 deletions src/providers/google-vertex-ai/api.ts
@@ -1,8 +1,9 @@
import { GatewayError } from '../../errors/GatewayError';
import { Options } from '../../types/requestBody';
import { endpointStrings, ProviderAPIConfig } from '../types';
import { getModelAndProvider, getAccessToken, getBucketAndFile } from './utils';

const getApiVersion = (provider: string, inputModel: string) => {
const getApiVersion = (provider: string) => {
🛠️ Code Refactor

Issue: Unused parameter inputModel in getApiVersion function
Fix: Remove unused parameter
Impact: Cleaner function signature

Suggested change
const getApiVersion = (provider: string) => {

if (provider === 'meta') return 'v1beta1';
return 'v1';
};
@@ -22,7 +23,7 @@ const getProjectRoute = (
}

const { provider } = getModelAndProvider(inputModel as string);
let routeVersion = getApiVersion(provider, inputModel as string);
const routeVersion = getApiVersion(provider);
🛠️ Code Refactor

Issue: Unused parameter in function call
Fix: Remove unused inputModel parameter
Impact: Cleaner function call

Suggested change
const routeVersion = getApiVersion(provider);

return `/${routeVersion}/projects/${projectId}/locations/${vertexRegion}`;
};

@@ -58,8 +59,9 @@ export const GoogleApiConfig: ProviderAPIConfig = {
}

if (vertexRegion === 'global') {
return `https://aiplatform.googleapis.com`;
return 'https://aiplatform.googleapis.com';
}

return `https://${vertexRegion}-aiplatform.googleapis.com`;
},
headers: async ({ c, providerOptions, gatewayRequestBody }) => {
@@ -68,7 +70,6 @@ export const GoogleApiConfig: ProviderAPIConfig = {
if (vertexServiceAccountJson) {
authToken = await getAccessToken(c, vertexServiceAccountJson);
}

const anthropicBeta =
providerOptions?.['anthropicBeta'] ??
gatewayRequestBody?.['anthropic_beta'];
@@ -95,6 +96,9 @@ export const GoogleApiConfig: ProviderAPIConfig = {
mappedFn = `stream-${fn}` as endpointStrings;
}

const url = new URL(gatewayRequestURL);
const searchParams = url.searchParams;

if (NON_INFERENCE_ENDPOINTS.includes(fn)) {
const jobIdIndex = [
'cancelBatch',
@@ -106,9 +110,9 @@ export const GoogleApiConfig: ProviderAPIConfig = {
const jobId = gatewayRequestURL.split('/').at(jobIdIndex);

const url = new URL(gatewayRequestURL);
const searchParams = url.searchParams;
const pageSize = searchParams.get('limit') ?? 20;
const after = searchParams.get('after') ?? '';
const params = new URLSearchParams(url.search);
const pageSize = params.get('limit') ?? 20;
const after = params.get('after') ?? '';

let projectId = vertexProjectId;
if (!projectId || vertexServiceAccountJson) {
@@ -147,9 +151,15 @@ export const GoogleApiConfig: ProviderAPIConfig = {
case 'cancelFinetune': {
return `/v1/projects/${projectId}/locations/${vertexRegion}/tuningJobs/${jobId}:cancel`;
}
default:
return '';
Comment on lines +154 to +155
🛠️ Code Refactor

Issue: Missing default case in switch statement
Fix: Add default case returning empty string
Impact: Prevents potential undefined returns

Suggested change
default:
return '';
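
To make the impact concrete, a small sketch (the routes here are made up, not the real endpoint map): without a `default`, an endpoint that matches no `case` falls out of the `switch`, the function implicitly returns `undefined`, and that value would later be interpolated into a URL as the literal string "undefined".

```ts
// Illustrative only.
function jobRoute(fn: string, projectId: string, jobId: string): string {
  switch (fn) {
    case 'cancelBatch':
      return `/v1/projects/${projectId}/batchPredictionJobs/${jobId}:cancel`;
    default:
      return ''; // explicit fallback keeps the return value a plain string
  }
}

jobRoute('unknownFn', 'my-project', '123'); // '' instead of undefined
```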

}
}

if (!inputModel) {
throw new GatewayError('Model is required', 400);
}
Comment on lines +159 to +161
🔒 Security Issue Fix

Issue: A missing model parameter could lead to unexpected behavior downstream
Fix: Add explicit model validation that throws a GatewayError with status 400
Impact: Prevents processing requests without the required model parameter

Suggested change
if (!inputModel) {
throw new GatewayError('Model is required', 400);
}
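
Tying this to the handlerUtils.ts change earlier in this PR, a hedged sketch of the end-to-end effect (`buildEndpoint` is an illustrative stand-in for the provider's endpoint logic):

```ts
import { GatewayError } from '../../errors/GatewayError';

function buildEndpoint(inputModel?: string): string {
  if (!inputModel) {
    // Caught upstream, where `error instanceof GatewayError ? error.status : 500`
    // surfaces this to the client as HTTP 400 rather than a generic 500.
    throw new GatewayError('Model is required', 400);
  }
  return `/models/${inputModel}`; // illustrative route only
}
```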


const { provider, model } = getModelAndProvider(inputModel as string);
const projectRoute = getProjectRoute(providerOptions, inputModel as string);
const googleUrlMap = new Map<string, string>([
@@ -188,6 +198,7 @@ export const GoogleApiConfig: ProviderAPIConfig = {
} else if (mappedFn === 'messagesCountTokens') {
return `${projectRoute}/publishers/${provider}/models/count-tokens:rawPredict`;
}
return `${projectRoute}/publishers/${provider}/models/${model}:rawPredict`;
🛠️ Code Refactor

Issue: Missing default return in switch statement
Fix: Add default return for rawPredict endpoint
Impact: Ensures consistent return value

Suggested change
return `${projectRoute}/publishers/${provider}/models/${model}:rawPredict`;

}

case 'meta': {
17 changes: 12 additions & 5 deletions src/providers/google-vertex-ai/chatComplete.ts
@@ -17,8 +17,8 @@ import {
AnthropicChatCompleteStreamResponse,
} from '../anthropic/chatComplete';
import {
AnthropicErrorResponse,
AnthropicStreamState,
AnthropicErrorResponse,
} from '../anthropic/types';
import {
GoogleMessage,
@@ -28,6 +28,7 @@ import {
transformOpenAIRoleToGoogleRole,
transformToolChoiceForGemini,
} from '../google/chatComplete';
import { GOOGLE_GENERATE_CONTENT_FINISH_REASON } from '../google/types';
import {
ChatCompletionResponse,
ErrorResponse,
@@ -295,7 +296,13 @@ export const VertexGoogleChatCompleteConfig: ProviderConfig = {
delete tool.function?.strict;

if (['googleSearch', 'google_search'].includes(tool.function.name)) {
tools.push({ googleSearch: {} });
const timeRangeFilter = tool.function.parameters?.timeRangeFilter;
tools.push({
googleSearch: {
// allow null
...(timeRangeFilter !== undefined && { timeRangeFilter }),
},
});
} else if (
['googleSearchRetrieval', 'google_search_retrieval'].includes(
tool.function.name
@@ -516,7 +523,7 @@ export const GoogleChatCompleteResponseTransform: (
message: message,
index: index,
finish_reason: transformFinishReason(
generation.finishReason,
generation.finishReason as GOOGLE_GENERATE_CONTENT_FINISH_REASON,
strictOpenAiCompliance
),
logprobs,
@@ -641,11 +648,11 @@ export const GoogleChatCompleteStreamChunkTransform: (
parsedChunk.candidates?.map((generation, index) => {
const finishReason = generation.finishReason
? transformFinishReason(
parsedChunk.candidates[0].finishReason,
parsedChunk.candidates[0]
.finishReason as GOOGLE_GENERATE_CONTENT_FINISH_REASON,
strictOpenAiCompliance
)
: null;

let message: any = { role: 'assistant', content: '' };
if (generation.content?.parts[0]?.text) {
const contentBlocks = [];
24 changes: 24 additions & 0 deletions src/providers/google-vertex-ai/createBatch.ts
@@ -1,3 +1,6 @@
import { constructConfigFromRequestHeaders } from '../../handlers/handlerUtils';
import { transformUsingProviderConfig } from '../../services/transformToProviderRequest';
import { Options } from '../../types/requestBody';
import { ProviderConfig } from '../types';
import { GoogleBatchRecord } from './types';
import { getModelAndProvider, GoogleToOpenAIBatch } from './utils';
@@ -69,6 +72,27 @@ export const GoogleBatchCreateConfig: ProviderConfig = {
},
};

export const GoogleBatchCreateRequestTransform = (
requestBody: any,
requestHeaders: Record<string, string>
) => {
const providerOptions = constructConfigFromRequestHeaders(requestHeaders);

const baseConfig = transformUsingProviderConfig(
GoogleBatchCreateConfig,
requestBody,
providerOptions as Options
);

const finalBody = {
// Contains extra fields like tags, and may also contain model, etc.; order matters so these are overridden by the params created using the config.
...requestBody?.provider_options,
...baseConfig,
};

return finalBody;
};

export const GoogleBatchCreateResponseTransform = (
response: Response,
responseStatus: number
27 changes: 14 additions & 13 deletions src/providers/google-vertex-ai/embed.ts
@@ -12,6 +12,7 @@ import {
transformEmbeddingInputs,
transformEmbeddingsParameters,
} from './transformGenerationConfig';
import { Params } from '../../types/requestBody';

enum TASK_TYPE {
RETRIEVAL_QUERY = 'RETRIEVAL_QUERY',
@@ -49,6 +50,19 @@ export const GoogleEmbedConfig: ProviderConfig = {
},
};

export const VertexBatchEmbedConfig: ProviderConfig = {
input: {
param: 'content',
required: true,
transform: (value: EmbedParams) => {
if (typeof value.input === 'string') {
return value.input;
}
return value.input.map((item) => item).join('\n');
},
},
};

export const GoogleEmbedResponseTransform: (
response: GoogleEmbedResponse | GoogleErrorResponse,
responseStatus: number,
@@ -120,16 +134,3 @@ export const GoogleEmbedResponseTransform: (

return generateInvalidProviderResponseError(response, GOOGLE_VERTEX_AI);
};

export const VertexBatchEmbedConfig: ProviderConfig = {
input: {
param: 'content',
required: true,
transform: (value: EmbedParams) => {
if (typeof value.input === 'string') {
return value.input;
}
return value.input.map((item) => item).join('\n');
},
},
};
46 changes: 32 additions & 14 deletions src/providers/google-vertex-ai/index.ts
@@ -20,33 +20,34 @@ import {
import { chatCompleteParams, responseTransformers } from '../open-ai-base';
import { GOOGLE_VERTEX_AI } from '../../globals';
import { Params } from '../../types/requestBody';
import {
GoogleFileUploadRequestHandler,
GoogleFileUploadResponseTransform,
} from './uploadFile';
import {
GoogleBatchCreateConfig,
GoogleBatchCreateRequestTransform,
GoogleBatchCreateResponseTransform,
} from './createBatch';
import { GoogleRetrieveBatchResponseTransform } from './retrieveBatch';
import {
BatchOutputRequestHandler,
BatchOutputResponseTransform,
} from './getBatchOutput';
import { GoogleListBatchesResponseTransform } from './listBatches';
import { GoogleCancelBatchResponseTransform } from './cancelBatch';
import {
GoogleFileUploadRequestHandler,
GoogleFileUploadResponseTransform,
} from './uploadFile';
import { GoogleRetrieveBatchResponseTransform } from './retrieveBatch';
import {
GoogleFinetuneCreateResponseTransform,
GoogleVertexFinetuneConfig,
} from './createFinetune';
import { GoogleRetrieveFileContentResponseTransform } from './retrieveFileContent';
import { GoogleListFilesRequestHandler } from './listFiles';
import {
GoogleRetrieveFileRequestHandler,
GoogleRetrieveFileResponseTransform,
} from './retrieveFile';
import { GoogleFinetuneRetrieveResponseTransform } from './retrieveFinetune';
import { GoogleFinetuneListResponseTransform } from './listFinetunes';
import { GoogleListFilesRequestHandler } from './listFiles';
import { GoogleFinetuneRetrieveResponseTransform } from './retrieveFinetune';
import { GoogleRetrieveFileContentResponseTransform } from './retrieveFileContent';
import {
VertexAnthropicMessagesConfig,
VertexAnthropicMessagesResponseTransform,
@@ -60,7 +61,7 @@ import {

const VertexConfig: ProviderConfigs = {
api: VertexApiConfig,
getConfig: ({ params }) => {
getConfig: (params: Params) => {
const requestConfig = {
uploadFile: {},
createBatch: GoogleBatchCreateConfig,
@@ -76,20 +77,25 @@ const VertexConfig: ProviderConfigs = {
const responseTransforms = {
uploadFile: GoogleFileUploadResponseTransform,
retrieveBatch: GoogleRetrieveBatchResponseTransform,
retrieveFile: GoogleRetrieveFileResponseTransform,
getBatchOutput: BatchOutputResponseTransform,
listBatches: GoogleListBatchesResponseTransform,
cancelBatch: GoogleCancelBatchResponseTransform,
createBatch: GoogleBatchCreateResponseTransform,
retrieveFileContent: GoogleRetrieveFileContentResponseTransform,
retrieveFile: GoogleRetrieveFileResponseTransform,
createFinetune: GoogleFinetuneCreateResponseTransform,
retrieveFinetune: GoogleFinetuneRetrieveResponseTransform,
listFinetunes: GoogleFinetuneListResponseTransform,
createBatch: GoogleBatchCreateResponseTransform,
retrieveFileContent: GoogleRetrieveFileContentResponseTransform,
};

const requestTransforms = {
createBatch: GoogleBatchCreateRequestTransform,
};

const baseConfig = {
...requestConfig,
responseTransforms,
requestTransforms,
};

const providerModel = params?.model;
@@ -115,6 +121,9 @@ const VertexConfig: ProviderConfigs = {
imageGenerate: GoogleImageGenResponseTransform,
...responseTransforms,
},
requestTransforms: {
...baseConfig.requestTransforms,
},
};
case 'anthropic':
return {
@@ -131,18 +140,24 @@ const VertexConfig: ProviderConfigs = {
messages: VertexAnthropicMessagesResponseTransform,
...responseTransforms,
},
requestTransforms: {
...baseConfig.requestTransforms,
},
};
case 'meta':
return {
chatComplete: VertexLlamaChatCompleteConfig,
createBatch: GoogleBatchCreateConfig,
api: GoogleApiConfig,
createBatch: GoogleBatchCreateConfig,
createFinetune: baseConfig.createFinetune,
responseTransforms: {
chatComplete: VertexLlamaChatCompleteResponseTransform,
'stream-chatComplete': VertexLlamaChatCompleteStreamChunkTransform,
...responseTransforms,
},
requestTransforms: {
...baseConfig.requestTransforms,
},
};
case 'endpoints':
return {
@@ -160,14 +175,17 @@ const VertexConfig: ProviderConfigs = {
}
),
createBatch: GoogleBatchCreateConfig,
api: GoogleApiConfig,
createFinetune: baseConfig.createFinetune,
api: GoogleApiConfig,
responseTransforms: {
...responseTransformers(GOOGLE_VERTEX_AI, {
chatComplete: true,
}),
...responseTransforms,
},
requestTransforms: {
...baseConfig.requestTransforms,
},
};
case 'mistralai':
return {
2 changes: 1 addition & 1 deletion src/providers/google-vertex-ai/listBatches.ts
@@ -1,6 +1,6 @@
import { GOOGLE_VERTEX_AI } from '../../globals';
import { generateInvalidProviderResponseError } from '../utils';
import { GoogleBatchRecord, GoogleErrorResponse } from './types';
import { generateInvalidProviderResponseError } from '../utils';
import { GoogleToOpenAIBatch } from './utils';

type GoogleListBatchesResponse = {
2 changes: 1 addition & 1 deletion src/providers/google-vertex-ai/messagesCountTokens.ts
@@ -7,7 +7,7 @@ export const VertexAnthropicMessagesCountTokensConfig = {
param: 'model',
required: true,
transform: (params: MessageCreateParamsBase) => {
let model = params.model ?? '';
const model = params.model ?? '';
return model.replace('anthropic.', '');
},
},