Skip to content

Commit 5e2fcdd

Browse files
authored
Merge pull request #711 from Portkey-AI/feat/aws-assume-role-setup
Feat: aws assume role setup
2 parents 2d3c874 + bd20558 commit 5e2fcdd

File tree

5 files changed

+122
-1
lines changed

5 files changed

+122
-1
lines changed

src/handlers/handlerUtils.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ export async function tryPostProxy(
269269
: (providerOption.urlToFetch as string);
270270

271271
const headers = await apiConfig.headers({
272+
c,
272273
providerOptions: providerOption,
273274
fn,
274275
transformedRequestBody: params,
@@ -520,6 +521,7 @@ export async function tryPost(
520521
const url = `${baseUrl}${endpoint}`;
521522

522523
const headers = await apiConfig.headers({
524+
c,
523525
providerOptions: providerOption,
524526
fn,
525527
transformedRequestBody,
@@ -1037,6 +1039,9 @@ export function constructConfigFromRequestHeaders(
10371039
awsSecretAccessKey: requestHeaders[`x-${POWERED_BY}-aws-secret-access-key`],
10381040
awsSessionToken: requestHeaders[`x-${POWERED_BY}-aws-session-token`],
10391041
awsRegion: requestHeaders[`x-${POWERED_BY}-aws-region`],
1042+
awsRoleArn: requestHeaders[`x-${POWERED_BY}-aws-role-arn`],
1043+
awsAuthType: requestHeaders[`x-${POWERED_BY}-aws-auth-type`],
1044+
awsExternalId: requestHeaders[`x-${POWERED_BY}-aws-external-id`],
10401045
};
10411046

10421047
const workersAiConfig = {

src/providers/bedrock/api.ts

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import { GatewayError } from '../../errors/GatewayError';
22
import { ProviderAPIConfig } from '../types';
33
import { bedrockInvokeModels } from './constants';
4-
import { generateAWSHeaders } from './utils';
4+
import { generateAWSHeaders, getAssumedRoleCredentials } from './utils';
55

66
const BedrockAPIConfig: ProviderAPIConfig = {
77
getBaseURL: ({ providerOptions }) =>
88
`https://bedrock-runtime.${providerOptions.awsRegion || 'us-east-1'}.amazonaws.com`,
99
headers: async ({
10+
c,
1011
providerOptions,
1112
transformedRequestBody,
1213
transformedRequestUrl,
@@ -15,6 +16,19 @@ const BedrockAPIConfig: ProviderAPIConfig = {
1516
'content-type': 'application/json',
1617
};
1718

19+
if (providerOptions.awsAuthType === 'assumedRole') {
20+
const { accessKeyId, secretAccessKey, sessionToken } =
21+
(await getAssumedRoleCredentials(
22+
c,
23+
providerOptions.awsRoleArn || '',
24+
providerOptions.awsExternalId || '',
25+
providerOptions.awsRegion || ''
26+
)) || {};
27+
providerOptions.awsAccessKeyId = accessKeyId;
28+
providerOptions.awsSecretAccessKey = secretAccessKey;
29+
providerOptions.awsSessionToken = sessionToken;
30+
}
31+
1832
return generateAWSHeaders(
1933
transformedRequestBody,
2034
headers,

src/providers/bedrock/utils.ts

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import {
66
BedrockChatCompletionsParams,
77
BedrockConverseCohereChatCompletionsParams,
88
} from './chatComplete';
9+
import { Context } from 'hono';
10+
import { env } from 'hono/adapter';
911

1012
export const generateAWSHeaders = async (
1113
body: Record<string, any>,
@@ -154,3 +156,98 @@ export const transformAI21AdditionalModelRequestFields = (
154156
}
155157
return additionalModelRequestFields;
156158
};
159+
160+
export async function getAssumedRoleCredentials(
161+
c: Context,
162+
awsRoleArn: string,
163+
awsExternalId: string,
164+
awsRegion: string
165+
) {
166+
const cacheKey = `${awsRoleArn}/${awsExternalId}/${awsRegion}`;
167+
const getFromCacheByKey = c.get('getFromCacheByKey');
168+
const putInCacheWithValue = c.get('putInCacheWithValue');
169+
170+
const {
171+
AWS_ASSUME_ROLE_ACCESS_KEY_ID,
172+
AWS_ASSUME_ROLE_SECRET_ACCESS_KEY,
173+
AWS_ASSUME_ROLE_REGION,
174+
} = env(c);
175+
const resp = getFromCacheByKey
176+
? await getFromCacheByKey(env(c), cacheKey)
177+
: null;
178+
if (resp) {
179+
return resp;
180+
}
181+
// Long-term credentials to assume role, static values from ENV
182+
const accessKeyId: string = AWS_ASSUME_ROLE_ACCESS_KEY_ID || '';
183+
const secretAccessKey: string = AWS_ASSUME_ROLE_SECRET_ACCESS_KEY || '';
184+
const region = awsRegion || AWS_ASSUME_ROLE_REGION || 'us-east-1';
185+
const service = 'sts';
186+
const hostname = `sts.${region}.amazonaws.com`;
187+
const signer = new SignatureV4({
188+
service,
189+
region,
190+
credentials: {
191+
accessKeyId,
192+
secretAccessKey,
193+
},
194+
sha256: Sha256,
195+
});
196+
const url = `https://${hostname}?Action=AssumeRole&Version=2011-06-15&RoleArn=${awsRoleArn}&ExternalId=${awsExternalId}&RoleSessionName=random`;
197+
const urlObj = new URL(url);
198+
const requestHeaders = { host: hostname };
199+
const options = {
200+
method: 'GET',
201+
path: urlObj.pathname,
202+
protocol: urlObj.protocol,
203+
hostname: urlObj.hostname,
204+
headers: requestHeaders,
205+
query: Object.fromEntries(urlObj.searchParams),
206+
};
207+
const { headers } = await signer.sign(options);
208+
209+
let credentials: any;
210+
try {
211+
const response = await fetch(url, {
212+
method: 'GET',
213+
headers: headers,
214+
});
215+
216+
if (!response.ok) {
217+
const resp = await response.text();
218+
console.error({ message: resp });
219+
throw new Error(`HTTP error! status: ${response.status}`);
220+
}
221+
222+
const xmlData = await response.text();
223+
credentials = parseXml(xmlData);
224+
if (putInCacheWithValue) {
225+
putInCacheWithValue(env(c), cacheKey, credentials, 60); //1 minute
226+
}
227+
} catch (error) {
228+
console.error({ message: `Error assuming role:, ${error}` });
229+
}
230+
231+
return credentials;
232+
}
233+
234+
function parseXml(xml: string) {
235+
// Simple XML parser for this specific use case
236+
const getTagContent = (tag: string) => {
237+
const regex = new RegExp(`<${tag}>(.*?)</${tag}>`, 's');
238+
const match = xml.match(regex);
239+
return match ? match[1] : null;
240+
};
241+
242+
const credentials = getTagContent('Credentials');
243+
if (!credentials) {
244+
throw new Error('Failed to parse Credentials from XML response');
245+
}
246+
247+
return {
248+
accessKeyId: getTagContent('AccessKeyId'),
249+
secretAccessKey: getTagContent('SecretAccessKey'),
250+
sessionToken: getTagContent('SessionToken'),
251+
expiration: getTagContent('Expiration'),
252+
};
253+
}

src/providers/types.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { Context } from 'hono';
12
import { Message, Options, Params } from '../types/requestBody';
23

34
/**
@@ -35,6 +36,7 @@ export interface ProviderConfig {
3536
export interface ProviderAPIConfig {
3637
/** A function to generate the headers for the API request. */
3738
headers: (args: {
39+
c: Context;
3840
providerOptions: Options;
3941
fn: string;
4042
transformedRequestBody: Record<string, any>;

src/types/requestBody.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ export interface Options {
7878
awsAccessKeyId?: string;
7979
awsSessionToken?: string;
8080
awsRegion?: string;
81+
awsAuthType?: string;
82+
awsRoleArn?: string;
83+
awsExternalId?: string;
8184

8285
/** Stability AI specific */
8386
stabilityClientId?: string;

0 commit comments

Comments
 (0)