|
6 | 6 | BedrockChatCompletionsParams,
|
7 | 7 | BedrockConverseCohereChatCompletionsParams,
|
8 | 8 | } from './chatComplete';
|
| 9 | +import { Context } from 'hono'; |
| 10 | +import { env } from 'hono/adapter'; |
9 | 11 |
|
10 | 12 | export const generateAWSHeaders = async (
|
11 | 13 | body: Record<string, any>,
|
@@ -154,3 +156,116 @@ export const transformAI21AdditionalModelRequestFields = (
|
154 | 156 | }
|
155 | 157 | return additionalModelRequestFields;
|
156 | 158 | };
|
| 159 | + |
| 160 | +export async function getAssumedRoleCredentials( |
| 161 | + c: Context, |
| 162 | + awsRoleArn: string, |
| 163 | + awsExternalId: string, |
| 164 | + awsRegion: string, |
| 165 | + creds?: { |
| 166 | + accessKeyId: string; |
| 167 | + secretAccessKey: string; |
| 168 | + sessionToken?: string; |
| 169 | + } |
| 170 | +) { |
| 171 | + const cacheKey = `${awsRoleArn}/${awsExternalId}/${awsRegion}`; |
| 172 | + const getFromCacheByKey = c.get('getFromCacheByKey'); |
| 173 | + const putInCacheWithValue = c.get('putInCacheWithValue'); |
| 174 | + |
| 175 | + const resp = getFromCacheByKey |
| 176 | + ? await getFromCacheByKey(env(c), cacheKey) |
| 177 | + : null; |
| 178 | + if (resp) { |
| 179 | + return resp; |
| 180 | + } |
| 181 | + |
| 182 | + // Determine which credentials to use |
| 183 | + let accessKeyId: string; |
| 184 | + let secretAccessKey: string; |
| 185 | + let sessionToken: string | undefined; |
| 186 | + |
| 187 | + if (creds) { |
| 188 | + // Use provided credentials |
| 189 | + accessKeyId = creds.accessKeyId; |
| 190 | + secretAccessKey = creds.secretAccessKey; |
| 191 | + sessionToken = creds.sessionToken; |
| 192 | + } else { |
| 193 | + // Use environment credentials |
| 194 | + const { AWS_ASSUME_ROLE_ACCESS_KEY_ID, AWS_ASSUME_ROLE_SECRET_ACCESS_KEY } = |
| 195 | + env(c); |
| 196 | + accessKeyId = AWS_ASSUME_ROLE_ACCESS_KEY_ID || ''; |
| 197 | + secretAccessKey = AWS_ASSUME_ROLE_SECRET_ACCESS_KEY || ''; |
| 198 | + } |
| 199 | + |
| 200 | + const region = awsRegion || 'us-east-1'; |
| 201 | + const service = 'sts'; |
| 202 | + const hostname = `sts.${region}.amazonaws.com`; |
| 203 | + const signer = new SignatureV4({ |
| 204 | + service, |
| 205 | + region, |
| 206 | + credentials: { |
| 207 | + accessKeyId, |
| 208 | + secretAccessKey, |
| 209 | + sessionToken, |
| 210 | + }, |
| 211 | + sha256: Sha256, |
| 212 | + }); |
| 213 | + const date = new Date(); |
| 214 | + const sessionName = `${date.getFullYear()}${date.getMonth()}${date.getDay()}`; |
| 215 | + const url = `https://${hostname}?Action=AssumeRole&Version=2011-06-15&RoleArn=${awsRoleArn}&RoleSessionName=${sessionName}${awsExternalId ? `&ExternalId=${awsExternalId}` : ''}`; |
| 216 | + const urlObj = new URL(url); |
| 217 | + const requestHeaders = { host: hostname }; |
| 218 | + const options = { |
| 219 | + method: 'GET', |
| 220 | + path: urlObj.pathname, |
| 221 | + protocol: urlObj.protocol, |
| 222 | + hostname: urlObj.hostname, |
| 223 | + headers: requestHeaders, |
| 224 | + query: Object.fromEntries(urlObj.searchParams), |
| 225 | + }; |
| 226 | + const { headers } = await signer.sign(options); |
| 227 | + |
| 228 | + let credentials: any; |
| 229 | + try { |
| 230 | + const response = await fetch(url, { |
| 231 | + method: 'GET', |
| 232 | + headers: headers, |
| 233 | + }); |
| 234 | + |
| 235 | + if (!response.ok) { |
| 236 | + const resp = await response.text(); |
| 237 | + console.error({ message: resp }); |
| 238 | + throw new Error(`HTTP error! status: ${response.status}`); |
| 239 | + } |
| 240 | + |
| 241 | + const xmlData = await response.text(); |
| 242 | + credentials = parseXml(xmlData); |
| 243 | + if (putInCacheWithValue) { |
| 244 | + putInCacheWithValue(env(c), cacheKey, credentials, 60); //1 minute |
| 245 | + } |
| 246 | + } catch (error) { |
| 247 | + console.error({ message: `Error assuming role:, ${error}` }); |
| 248 | + } |
| 249 | + return credentials; |
| 250 | +} |
| 251 | + |
| 252 | +function parseXml(xml: string) { |
| 253 | + // Simple XML parser for this specific use case |
| 254 | + const getTagContent = (tag: string) => { |
| 255 | + const regex = new RegExp(`<${tag}>(.*?)</${tag}>`, 's'); |
| 256 | + const match = xml.match(regex); |
| 257 | + return match ? match[1] : null; |
| 258 | + }; |
| 259 | + |
| 260 | + const credentials = getTagContent('Credentials'); |
| 261 | + if (!credentials) { |
| 262 | + throw new Error('Failed to parse Credentials from XML response'); |
| 263 | + } |
| 264 | + |
| 265 | + return { |
| 266 | + accessKeyId: getTagContent('AccessKeyId'), |
| 267 | + secretAccessKey: getTagContent('SecretAccessKey'), |
| 268 | + sessionToken: getTagContent('SessionToken'), |
| 269 | + expiration: getTagContent('Expiration'), |
| 270 | + }; |
| 271 | +} |
0 commit comments