Skip to content

Commit ad35209

Browse files
authored
Merge branch 'main' into feat/azure_entraid_integration
2 parents e17a75b + fa25367 commit ad35209

File tree

9 files changed

+198
-22
lines changed

9 files changed

+198
-22
lines changed

package-lock.json

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@portkey-ai/gateway",
3-
"version": "1.7.7",
3+
"version": "1.8.0",
44
"description": "A fast AI gateway by Portkey",
55
"repository": {
66
"type": "git",

src/handlers/handlerUtils.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ export async function tryPostProxy(
269269
: (providerOption.urlToFetch as string);
270270

271271
const headers = await apiConfig.headers({
272+
c,
272273
providerOptions: providerOption,
273274
fn,
274275
transformedRequestBody: params,
@@ -520,6 +521,7 @@ export async function tryPost(
520521
const url = `${baseUrl}${endpoint}`;
521522

522523
const headers = await apiConfig.headers({
524+
c,
523525
providerOptions: providerOption,
524526
fn,
525527
transformedRequestBody,
@@ -1044,6 +1046,9 @@ export function constructConfigFromRequestHeaders(
10441046
awsSecretAccessKey: requestHeaders[`x-${POWERED_BY}-aws-secret-access-key`],
10451047
awsSessionToken: requestHeaders[`x-${POWERED_BY}-aws-session-token`],
10461048
awsRegion: requestHeaders[`x-${POWERED_BY}-aws-region`],
1049+
awsRoleArn: requestHeaders[`x-${POWERED_BY}-aws-role-arn`],
1050+
awsAuthType: requestHeaders[`x-${POWERED_BY}-aws-auth-type`],
1051+
awsExternalId: requestHeaders[`x-${POWERED_BY}-aws-external-id`],
10471052
};
10481053

10491054
const workersAiConfig = {

src/providers/anthropic/api.ts

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,20 @@ import { ProviderAPIConfig } from '../types';
22

33
const AnthropicAPIConfig: ProviderAPIConfig = {
44
getBaseURL: () => 'https://api.anthropic.com/v1',
5-
headers: ({ providerOptions, fn }) => {
5+
headers: ({ providerOptions, fn, gatewayRequestBody }) => {
66
const headers: Record<string, string> = {
77
'X-API-Key': `${providerOptions.apiKey}`,
88
};
99

10+
// Accept anthropic_beta and anthropic_version in body to support enviroments which cannot send it in headers.
1011
const betaHeader =
11-
providerOptions?.['anthropicBeta'] ?? 'messages-2023-12-15';
12-
const version = providerOptions?.['anthropicVersion'] ?? '2023-06-01';
12+
providerOptions?.['anthropicBeta'] ??
13+
gatewayRequestBody?.['anthropic_beta'] ??
14+
'messages-2023-12-15';
15+
const version =
16+
providerOptions?.['anthropicVersion'] ??
17+
gatewayRequestBody?.['anthropic_version'] ??
18+
'2023-06-01';
1319

1420
if (fn === 'chatComplete') {
1521
headers['anthropic-beta'] = betaHeader;

src/providers/bedrock/api.ts

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1+
import { env } from 'hono/adapter';
12
import { GatewayError } from '../../errors/GatewayError';
23
import { ProviderAPIConfig } from '../types';
34
import { bedrockInvokeModels } from './constants';
4-
import { generateAWSHeaders } from './utils';
5+
import { generateAWSHeaders, getAssumedRoleCredentials } from './utils';
56

67
const BedrockAPIConfig: ProviderAPIConfig = {
78
getBaseURL: ({ providerOptions }) =>
89
`https://bedrock-runtime.${providerOptions.awsRegion || 'us-east-1'}.amazonaws.com`,
910
headers: async ({
11+
c,
1012
providerOptions,
1113
transformedRequestBody,
1214
transformedRequestUrl,
@@ -15,6 +17,41 @@ const BedrockAPIConfig: ProviderAPIConfig = {
1517
'content-type': 'application/json',
1618
};
1719

20+
if (providerOptions.awsAuthType === 'assumedRole') {
21+
try {
22+
// Assume the role in the source account
23+
const sourceRoleCredentials = await getAssumedRoleCredentials(
24+
c,
25+
env(c).AWS_ASSUME_ROLE_SOURCE_ARN, // Role ARN in the source account
26+
env(c).AWS_ASSUME_ROLE_SOURCE_EXTERNAL_ID || '', // External ID for source role (if needed)
27+
providerOptions.awsRegion || ''
28+
);
29+
30+
if (!sourceRoleCredentials) {
31+
throw new Error('Server Error while assuming internal role');
32+
}
33+
34+
// Assume role in destination account using temporary creds obtained in first step
35+
const { accessKeyId, secretAccessKey, sessionToken } =
36+
(await getAssumedRoleCredentials(
37+
c,
38+
providerOptions.awsRoleArn || '',
39+
providerOptions.awsExternalId || '',
40+
providerOptions.awsRegion || '',
41+
{
42+
accessKeyId: sourceRoleCredentials.accessKeyId,
43+
secretAccessKey: sourceRoleCredentials.secretAccessKey,
44+
sessionToken: sourceRoleCredentials.sessionToken,
45+
}
46+
)) || {};
47+
providerOptions.awsAccessKeyId = accessKeyId;
48+
providerOptions.awsSecretAccessKey = secretAccessKey;
49+
providerOptions.awsSessionToken = sessionToken;
50+
} catch (e) {
51+
throw new GatewayError('Error while assuming bedrock role');
52+
}
53+
}
54+
1855
return generateAWSHeaders(
1956
transformedRequestBody,
2057
headers,

src/providers/bedrock/utils.ts

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import {
66
BedrockChatCompletionsParams,
77
BedrockConverseCohereChatCompletionsParams,
88
} from './chatComplete';
9+
import { Context } from 'hono';
10+
import { env } from 'hono/adapter';
911

1012
export const generateAWSHeaders = async (
1113
body: Record<string, any>,
@@ -154,3 +156,116 @@ export const transformAI21AdditionalModelRequestFields = (
154156
}
155157
return additionalModelRequestFields;
156158
};
159+
160+
export async function getAssumedRoleCredentials(
161+
c: Context,
162+
awsRoleArn: string,
163+
awsExternalId: string,
164+
awsRegion: string,
165+
creds?: {
166+
accessKeyId: string;
167+
secretAccessKey: string;
168+
sessionToken?: string;
169+
}
170+
) {
171+
const cacheKey = `${awsRoleArn}/${awsExternalId}/${awsRegion}`;
172+
const getFromCacheByKey = c.get('getFromCacheByKey');
173+
const putInCacheWithValue = c.get('putInCacheWithValue');
174+
175+
const resp = getFromCacheByKey
176+
? await getFromCacheByKey(env(c), cacheKey)
177+
: null;
178+
if (resp) {
179+
return resp;
180+
}
181+
182+
// Determine which credentials to use
183+
let accessKeyId: string;
184+
let secretAccessKey: string;
185+
let sessionToken: string | undefined;
186+
187+
if (creds) {
188+
// Use provided credentials
189+
accessKeyId = creds.accessKeyId;
190+
secretAccessKey = creds.secretAccessKey;
191+
sessionToken = creds.sessionToken;
192+
} else {
193+
// Use environment credentials
194+
const { AWS_ASSUME_ROLE_ACCESS_KEY_ID, AWS_ASSUME_ROLE_SECRET_ACCESS_KEY } =
195+
env(c);
196+
accessKeyId = AWS_ASSUME_ROLE_ACCESS_KEY_ID || '';
197+
secretAccessKey = AWS_ASSUME_ROLE_SECRET_ACCESS_KEY || '';
198+
}
199+
200+
const region = awsRegion || 'us-east-1';
201+
const service = 'sts';
202+
const hostname = `sts.${region}.amazonaws.com`;
203+
const signer = new SignatureV4({
204+
service,
205+
region,
206+
credentials: {
207+
accessKeyId,
208+
secretAccessKey,
209+
sessionToken,
210+
},
211+
sha256: Sha256,
212+
});
213+
const date = new Date();
214+
const sessionName = `${date.getFullYear()}${date.getMonth()}${date.getDay()}`;
215+
const url = `https://${hostname}?Action=AssumeRole&Version=2011-06-15&RoleArn=${awsRoleArn}&RoleSessionName=${sessionName}${awsExternalId ? `&ExternalId=${awsExternalId}` : ''}`;
216+
const urlObj = new URL(url);
217+
const requestHeaders = { host: hostname };
218+
const options = {
219+
method: 'GET',
220+
path: urlObj.pathname,
221+
protocol: urlObj.protocol,
222+
hostname: urlObj.hostname,
223+
headers: requestHeaders,
224+
query: Object.fromEntries(urlObj.searchParams),
225+
};
226+
const { headers } = await signer.sign(options);
227+
228+
let credentials: any;
229+
try {
230+
const response = await fetch(url, {
231+
method: 'GET',
232+
headers: headers,
233+
});
234+
235+
if (!response.ok) {
236+
const resp = await response.text();
237+
console.error({ message: resp });
238+
throw new Error(`HTTP error! status: ${response.status}`);
239+
}
240+
241+
const xmlData = await response.text();
242+
credentials = parseXml(xmlData);
243+
if (putInCacheWithValue) {
244+
putInCacheWithValue(env(c), cacheKey, credentials, 60); //1 minute
245+
}
246+
} catch (error) {
247+
console.error({ message: `Error assuming role:, ${error}` });
248+
}
249+
return credentials;
250+
}
251+
252+
function parseXml(xml: string) {
253+
// Simple XML parser for this specific use case
254+
const getTagContent = (tag: string) => {
255+
const regex = new RegExp(`<${tag}>(.*?)</${tag}>`, 's');
256+
const match = xml.match(regex);
257+
return match ? match[1] : null;
258+
};
259+
260+
const credentials = getTagContent('Credentials');
261+
if (!credentials) {
262+
throw new Error('Failed to parse Credentials from XML response');
263+
}
264+
265+
return {
266+
accessKeyId: getTagContent('AccessKeyId'),
267+
secretAccessKey: getTagContent('SecretAccessKey'),
268+
sessionToken: getTagContent('SessionToken'),
269+
expiration: getTagContent('Expiration'),
270+
};
271+
}

src/providers/ollama/chatComplete.ts

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ import {
44
ProviderConfig,
55
} from '../types';
66
import { OLLAMA } from '../../globals';
7-
import { generateErrorResponse } from '../utils';
7+
import {
8+
generateErrorResponse,
9+
generateInvalidProviderResponseError,
10+
} from '../utils';
811

912
export const OllamaChatCompleteConfig: ProviderConfig = {
1013
model: {
@@ -63,9 +66,7 @@ export const OllamaChatCompleteConfig: ProviderConfig = {
6366
},
6467
};
6568

66-
export interface OllamaChatCompleteResponse
67-
extends ChatCompletionResponse,
68-
ErrorResponse {
69+
export interface OllamaChatCompleteResponse extends ChatCompletionResponse {
6970
system_fingerprint: string;
7071
}
7172

@@ -86,10 +87,10 @@ export interface OllamaStreamChunk {
8687
}
8788

8889
export const OllamaChatCompleteResponseTransform: (
89-
response: OllamaChatCompleteResponse,
90+
response: OllamaChatCompleteResponse | ErrorResponse,
9091
responseStatus: number
9192
) => ChatCompletionResponse | ErrorResponse = (response, responseStatus) => {
92-
if (responseStatus !== 200) {
93+
if (responseStatus !== 200 && 'error' in response) {
9394
return generateErrorResponse(
9495
{
9596
message: response.error?.message,
@@ -101,15 +102,19 @@ export const OllamaChatCompleteResponseTransform: (
101102
);
102103
}
103104

104-
return {
105-
id: response.id,
106-
object: response.object,
107-
created: response.created,
108-
model: response.model,
109-
provider: OLLAMA,
110-
choices: response.choices,
111-
usage: response.usage,
112-
};
105+
if ('choices' in response) {
106+
return {
107+
id: response.id,
108+
object: response.object,
109+
created: response.created,
110+
model: response.model,
111+
provider: OLLAMA,
112+
choices: response.choices,
113+
usage: response.usage,
114+
};
115+
}
116+
117+
return generateInvalidProviderResponseError(response, OLLAMA);
113118
};
114119

115120
export const OllamaChatCompleteStreamChunkTransform: (

src/providers/types.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { Context } from 'hono';
12
import { Message, Options, Params } from '../types/requestBody';
23

34
/**
@@ -35,6 +36,7 @@ export interface ProviderConfig {
3536
export interface ProviderAPIConfig {
3637
/** A function to generate the headers for the API request. */
3738
headers: (args: {
39+
c: Context;
3840
providerOptions: Options;
3941
fn: string;
4042
transformedRequestBody: Record<string, any>;

src/types/requestBody.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ export interface Options {
8383
awsAccessKeyId?: string;
8484
awsSessionToken?: string;
8585
awsRegion?: string;
86+
awsAuthType?: string;
87+
awsRoleArn?: string;
88+
awsExternalId?: string;
8689

8790
/** Stability AI specific */
8891
stabilityClientId?: string;
@@ -325,6 +328,9 @@ export interface Params {
325328
};
326329
// Google Vertex AI specific
327330
safety_settings?: any;
331+
// Anthropic specific
332+
anthropic_beta?: string;
333+
anthropic_version?: string;
328334
}
329335

330336
interface Examples {

0 commit comments

Comments
 (0)