|
6 | 6 | BedrockChatCompletionsParams,
|
7 | 7 | BedrockConverseCohereChatCompletionsParams,
|
8 | 8 | } from './chatComplete';
|
| 9 | +import { Context } from 'hono'; |
| 10 | +import { env } from 'hono/adapter'; |
9 | 11 |
|
10 | 12 | export const generateAWSHeaders = async (
|
11 | 13 | body: Record<string, any>,
|
@@ -154,3 +156,98 @@ export const transformAI21AdditionalModelRequestFields = (
|
154 | 156 | }
|
155 | 157 | return additionalModelRequestFields;
|
156 | 158 | };
|
| 159 | + |
| 160 | +export async function getAssumedRoleCredentials( |
| 161 | + c: Context, |
| 162 | + awsRoleArn: string, |
| 163 | + awsExternalId: string, |
| 164 | + awsRegion: string |
| 165 | +) { |
| 166 | + const cacheKey = `${awsRoleArn}/${awsExternalId}/${awsRegion}`; |
| 167 | + const getFromCacheByKey = c.get('getFromCacheByKey'); |
| 168 | + const putInCacheWithValue = c.get('putInCacheWithValue'); |
| 169 | + |
| 170 | + const { |
| 171 | + AWS_ASSUME_ROLE_ACCESS_KEY_ID, |
| 172 | + AWS_ASSUME_ROLE_SECRET_ACCESS_KEY, |
| 173 | + AWS_ASSUME_ROLE_REGION, |
| 174 | + } = env(c); |
| 175 | + const resp = getFromCacheByKey |
| 176 | + ? await getFromCacheByKey(env(c), cacheKey) |
| 177 | + : null; |
| 178 | + if (resp) { |
| 179 | + return resp; |
| 180 | + } |
| 181 | + // Long-term credentials to assume role, static values from ENV |
| 182 | + const accessKeyId: string = AWS_ASSUME_ROLE_ACCESS_KEY_ID || ''; |
| 183 | + const secretAccessKey: string = AWS_ASSUME_ROLE_SECRET_ACCESS_KEY || ''; |
| 184 | + const region = awsRegion || AWS_ASSUME_ROLE_REGION || 'us-east-1'; |
| 185 | + const service = 'sts'; |
| 186 | + const hostname = `sts.${region}.amazonaws.com`; |
| 187 | + const signer = new SignatureV4({ |
| 188 | + service, |
| 189 | + region, |
| 190 | + credentials: { |
| 191 | + accessKeyId, |
| 192 | + secretAccessKey, |
| 193 | + }, |
| 194 | + sha256: Sha256, |
| 195 | + }); |
| 196 | + const url = `https://${hostname}?Action=AssumeRole&Version=2011-06-15&RoleArn=${awsRoleArn}&ExternalId=${awsExternalId}&RoleSessionName=random`; |
| 197 | + const urlObj = new URL(url); |
| 198 | + const requestHeaders = { host: hostname }; |
| 199 | + const options = { |
| 200 | + method: 'GET', |
| 201 | + path: urlObj.pathname, |
| 202 | + protocol: urlObj.protocol, |
| 203 | + hostname: urlObj.hostname, |
| 204 | + headers: requestHeaders, |
| 205 | + query: Object.fromEntries(urlObj.searchParams), |
| 206 | + }; |
| 207 | + const { headers } = await signer.sign(options); |
| 208 | + |
| 209 | + let credentials: any; |
| 210 | + try { |
| 211 | + const response = await fetch(url, { |
| 212 | + method: 'GET', |
| 213 | + headers: headers, |
| 214 | + }); |
| 215 | + |
| 216 | + if (!response.ok) { |
| 217 | + const resp = await response.text(); |
| 218 | + console.error({ message: resp }); |
| 219 | + throw new Error(`HTTP error! status: ${response.status}`); |
| 220 | + } |
| 221 | + |
| 222 | + const xmlData = await response.text(); |
| 223 | + credentials = parseXml(xmlData); |
| 224 | + if (putInCacheWithValue) { |
| 225 | + putInCacheWithValue(env(c), cacheKey, credentials, 60); //1 minute |
| 226 | + } |
| 227 | + } catch (error) { |
| 228 | + console.error({ message: `Error assuming role:, ${error}` }); |
| 229 | + } |
| 230 | + |
| 231 | + return credentials; |
| 232 | +} |
| 233 | + |
| 234 | +function parseXml(xml: string) { |
| 235 | + // Simple XML parser for this specific use case |
| 236 | + const getTagContent = (tag: string) => { |
| 237 | + const regex = new RegExp(`<${tag}>(.*?)</${tag}>`, 's'); |
| 238 | + const match = xml.match(regex); |
| 239 | + return match ? match[1] : null; |
| 240 | + }; |
| 241 | + |
| 242 | + const credentials = getTagContent('Credentials'); |
| 243 | + if (!credentials) { |
| 244 | + throw new Error('Failed to parse Credentials from XML response'); |
| 245 | + } |
| 246 | + |
| 247 | + return { |
| 248 | + accessKeyId: getTagContent('AccessKeyId'), |
| 249 | + secretAccessKey: getTagContent('SecretAccessKey'), |
| 250 | + sessionToken: getTagContent('SessionToken'), |
| 251 | + expiration: getTagContent('Expiration'), |
| 252 | + }; |
| 253 | +} |
0 commit comments