Skip to content

Commit 2835e88

Browse files
authored
feat: distribute bedrock requests per regional limits (supabase#37049)
1 parent 140d602 commit 2835e88

File tree

6 files changed

+201
-128
lines changed

6 files changed

+201
-128
lines changed

apps/studio/lib/ai/bedrock.test.ts

Lines changed: 0 additions & 73 deletions
This file was deleted.

apps/studio/lib/ai/bedrock.ts

Lines changed: 65 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import { createAmazonBedrock } from '@ai-sdk/amazon-bedrock'
22
import { createCredentialChain, fromNodeProviderChain } from '@aws-sdk/credential-providers'
33
import { CredentialsProviderError } from '@smithy/property-provider'
44
import { awsCredentialsProvider } from '@vercel/functions/oidc'
5+
import { selectWeightedKey } from './util'
56

67
const credentialProvider = createCredentialChain(
78
// Vercel OIDC provider will be used for staging/production
@@ -34,45 +35,82 @@ async function vercelOidcProvider() {
3435
}
3536
}
3637

38+
/**
 * Reports whether the module-level credential provider chain can resolve
 * AWS credentials.
 *
 * Any failure while resolving credentials is treated as "no credentials"
 * rather than surfaced to the caller.
 *
 * @returns `true` when the provider resolves to a truthy credentials value,
 * `false` when it resolves to a falsy value or throws.
 */
export async function checkAwsCredentials() {
  try {
    return Boolean(await credentialProvider())
  } catch {
    // Resolution errors simply mean Bedrock is unavailable
    return false
  }
}
46+
3747
export const bedrockRegionMap = {
38-
us1: 'us-east-1',
39-
us3: 'us-west-2',
48+
use1: 'us-east-1',
49+
use2: 'us-east-2',
50+
usw2: 'us-west-2',
51+
euc1: 'eu-central-1',
4052
} as const
4153

4254
export type BedrockRegion = keyof typeof bedrockRegionMap
4355

44-
export const bedrockForRegion = (region: BedrockRegion) =>
45-
createAmazonBedrock({
46-
credentialProvider,
47-
region: bedrockRegionMap[region],
48-
})
56+
/**
 * Model ID prefix for each Bedrock region key.
 *
 * The prefix is joined to the model ID as `<prefix>.<modelId>` when a
 * request is routed to that region.
 */
export const regionPrefixMap: Record<BedrockRegion, string> = {
  // US regions
  use1: 'us',
  use2: 'us',
  usw2: 'us',
  // EU regions
  euc1: 'eu',
}
62+
63+
export type BedrockModel =
64+
| 'anthropic.claude-3-7-sonnet-20250219-v1:0'
65+
| 'anthropic.claude-3-5-haiku-20241022-v1:0'
66+
67+
export type RegionWeights = Record<BedrockRegion, number>
4968

5069
/**
51-
* Selects a region based on a routing key using a consistent hashing algorithm.
70+
* Weights for distributing requests across Bedrock regions.
71+
* Weights are proportional to our rate limits per model per region.
72+
*/
73+
const modelRegionWeights: Record<BedrockModel, RegionWeights> = {
  // Sonnet is spread over every region, with use1 taking the largest share
  'anthropic.claude-3-7-sonnet-20250219-v1:0': {
    use1: 40,
    use2: 10,
    usw2: 10,
    euc1: 10,
  },
  // Haiku is split evenly between use1 and usw2; regions weighted 0 are
  // never chosen by weighted selection
  'anthropic.claude-3-5-haiku-20241022-v1:0': {
    use1: 40,
    use2: 0,
    usw2: 40,
    euc1: 0,
  },
}
87+
88+
/**
89+
* Creates a Bedrock client that routes requests to different regions
90+
* based on a routing key.
5291
*
53-
* Ensures that the same key always maps to the same region
54-
* while distributing keys evenly across available regions.
92+
* Used to load balance requests across multiple regions depending on
93+
* their capacities.
5594
*/
56-
export async function selectBedrockRegion(routingKey: string) {
57-
const regions = Object.keys(bedrockRegionMap) as BedrockRegion[]
58-
const encoder = new TextEncoder()
59-
const data = encoder.encode(routingKey)
60-
const hashBuffer = await crypto.subtle.digest('SHA-256', data)
95+
export function createRoutedBedrock(routingKey?: string) {
96+
return async (modelId: BedrockModel) => {
97+
const regionWeights = modelRegionWeights[modelId]
6198

62-
// Use first 4 bytes (32 bit integer)
63-
const hashInt = new DataView(hashBuffer).getUint32(0)
99+
// Select the Bedrock region based on the routing key and the model
100+
const bedrockRegion = routingKey
101+
? await selectWeightedKey(routingKey, regionWeights)
102+
: // There's a few places where getModel is called without a routing key
103+
// Will cause disproportionate load on use1 region
104+
'use1'
64105

65-
// Use modulo to map to available regions
66-
const regionIndex = hashInt % regions.length
106+
const bedrock = createAmazonBedrock({
107+
credentialProvider,
108+
region: bedrockRegionMap[bedrockRegion],
109+
})
67110

68-
return regions[regionIndex]
69-
}
111+
// Cross-region models require the region prefix
112+
const modelName = `${regionPrefixMap[bedrockRegion]}.${modelId}`
70113

71-
export async function checkAwsCredentials() {
72-
try {
73-
const credentials = await credentialProvider()
74-
return !!credentials
75-
} catch (error) {
76-
return false
114+
return bedrock(modelName)
77115
}
78116
}

apps/studio/lib/ai/model.test.ts

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ vi.mock('@ai-sdk/openai', () => ({
77
openai: vi.fn(() => 'openai-model'),
88
}))
99

10-
vi.mock('./bedrock', () => ({
11-
bedrockForRegion: vi.fn(() => () => 'bedrock-model'),
10+
vi.mock('./bedrock', async () => ({
11+
...(await vi.importActual('./bedrock')),
12+
createRoutedBedrock: vi.fn(() => () => 'bedrock-model'),
1213
checkAwsCredentials: vi.fn(),
13-
selectBedrockRegion: vi.fn(() => 'us'),
1414
}))
1515

1616
describe('getModel', () => {
@@ -29,18 +29,15 @@ describe('getModel', () => {
2929

3030
const { model, error } = await getModel()
3131

32-
console.log('Model:', model)
33-
3432
expect(model).toEqual('bedrock-model')
35-
expect(bedrockModule.bedrockForRegion).toHaveBeenCalledWith('us1')
3633
expect(error).toBeUndefined()
3734
})
3835

3936
it('should return OpenAI model when AWS credentials are not available but OPENAI_API_KEY is set', async () => {
4037
vi.mocked(bedrockModule.checkAwsCredentials).mockResolvedValue(false)
4138
process.env.OPENAI_API_KEY = 'test-key'
4239

43-
const { model } = await getModel('test-key')
40+
const { model } = await getModel()
4441

4542
expect(model).toEqual('openai-model')
4643
expect(openai).toHaveBeenCalledWith('gpt-4.1-2025-04-14')

apps/studio/lib/ai/model.ts

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,12 @@
11
import { openai } from '@ai-sdk/openai'
22
import { LanguageModel } from 'ai'
3-
import {
4-
bedrockForRegion,
5-
BedrockRegion,
6-
checkAwsCredentials,
7-
selectBedrockRegion,
8-
} from './bedrock'
9-
10-
export const regionMap = {
11-
us1: 'us',
12-
us2: 'us',
13-
us3: 'us',
14-
eu: 'eu',
15-
}
3+
import { checkAwsCredentials, createRoutedBedrock } from './bedrock'
164

175
// Default behaviour here is to be throttled (e.g if this env var is not available, IS_THROTTLED should be true, unless specified 'false')
186
const IS_THROTTLED = process.env.IS_THROTTLED !== 'false'
19-
const PRO_MODEL = process.env.AI_PRO_MODEL ?? 'anthropic.claude-3-7-sonnet-20250219-v1:0'
20-
const NORMAL_MODEL = process.env.AI_NORMAL_MODEL ?? 'anthropic.claude-3-5-haiku-20241022-v1:0'
7+
8+
const BEDROCK_PRO_MODEL = 'anthropic.claude-3-7-sonnet-20250219-v1:0'
9+
const BEDROCK_NORMAL_MODEL = 'anthropic.claude-3-5-haiku-20241022-v1:0'
2110
const OPENAI_MODEL = 'gpt-4.1-2025-04-14'
2211

2312
export type ModelSuccess = {
@@ -46,14 +35,11 @@ export async function getModel(routingKey?: string, isLimited?: boolean): Promis
4635
const hasOpenAIKey = !!process.env.OPENAI_API_KEY
4736

4837
if (hasAwsCredentials) {
49-
// Select the Bedrock region based on the routing key
50-
const bedrockRegion: BedrockRegion = routingKey ? await selectBedrockRegion(routingKey) : 'us1'
51-
const bedrock = bedrockForRegion(bedrockRegion)
52-
const model = IS_THROTTLED || isLimited ? NORMAL_MODEL : PRO_MODEL
53-
const modelName = `${regionMap[bedrockRegion]}.${model}`
38+
const bedrockModel = IS_THROTTLED || isLimited ? BEDROCK_NORMAL_MODEL : BEDROCK_PRO_MODEL
39+
const bedrock = createRoutedBedrock(routingKey)
5440

5541
return {
56-
model: bedrock(modelName),
42+
model: await bedrock(bedrockModel),
5743
}
5844
}
5945

apps/studio/lib/ai/util.test.ts

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import { describe, it, expect } from 'vitest'
2+
import { selectWeightedKey } from './util'
3+
4+
/**
 * Runs `selectWeightedKey` once per input and tallies how many times each
 * key was returned. Keys that were never returned are absent from the result.
 */
async function countSelections<T extends string>(
  inputs: string[],
  weights: Record<T, number>
): Promise<Record<string, number>> {
  const selected = await Promise.all(inputs.map((input) => selectWeightedKey(input, weights)))

  const counts: Record<string, number> = {}
  for (const key of selected) {
    counts[key] = (counts[key] ?? 0) + 1
  }
  return counts
}

describe('selectWeightedKey', () => {
  it('should return a valid key from the weights object', async () => {
    const weights = { a: 10, b: 20, c: 30 }
    const result = await selectWeightedKey('test-input', weights)

    expect(Object.keys(weights)).toContain(result)
  })

  it('should return consistent results for the same input', async () => {
    const weights = { region1: 40, region2: 10, region3: 20 }
    const input = 'consistent-key'

    const result1 = await selectWeightedKey(input, weights)
    const result2 = await selectWeightedKey(input, weights)
    const result3 = await selectWeightedKey(input, weights)

    expect(result1).toBe(result2)
    expect(result2).toBe(result3)
  })

  it('should distribute keys according to weights', async () => {
    const weights = { a: 80, b: 10, c: 10 }
    const numSamples = 10000
    const samples = Array.from({ length: numSamples }, (_, i) => `sample-${i}`)

    const counts = await countSelections(samples, weights)

    expect((counts.a ?? 0) / numSamples).toBeCloseTo(0.8, 1)
    expect((counts.b ?? 0) / numSamples).toBeCloseTo(0.1, 1)
    expect((counts.c ?? 0) / numSamples).toBeCloseTo(0.1, 1)
  })

  it('should handle equal weights', async () => {
    const weights = { x: 25, y: 25, z: 25, w: 25 }
    const numSamples = 8000
    const samples = Array.from({ length: numSamples }, (_, i) => `equal-${i}`)

    const counts = await countSelections(samples, weights)

    // Iterate over the configured keys (not the observed ones) so that a key
    // which is never selected fails the test instead of being skipped
    for (const key of Object.keys(weights)) {
      expect((counts[key] ?? 0) / numSamples).toBeCloseTo(0.25, 1)
    }
  })

  it('should never select a key with zero weight', async () => {
    const weights = { a: 40, b: 0, c: 40, d: 0 }
    const samples = Array.from({ length: 1000 }, (_, i) => `zero-${i}`)

    const counts = await countSelections(samples, weights)

    expect(Object.keys(counts)).not.toContain('b')
    expect(Object.keys(counts)).not.toContain('d')
  })

  it('should handle single key', async () => {
    const weights = { only: 100 }
    const result = await selectWeightedKey('any-input', weights)

    expect(result).toBe('only')
  })

  it('should handle empty string input', async () => {
    const weights = { a: 10, b: 20 }
    const result = await selectWeightedKey('', weights)

    expect(Object.keys(weights)).toContain(result)
  })

  it('should handle unicode characters in input', async () => {
    const weights = { option1: 50, option2: 50 }
    const unicodeInput = '🔑-unicode-key-测试'

    const result1 = await selectWeightedKey(unicodeInput, weights)
    const result2 = await selectWeightedKey(unicodeInput, weights)

    expect(result1).toBe(result2)
    expect(Object.keys(weights)).toContain(result1)
  })
})

apps/studio/lib/ai/util.ts

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/**
 * Deterministically selects a key from weighted choices by hashing an
 * input string.
 *
 * The input is hashed with SHA-256 and the first 32 bits of the digest are
 * reduced modulo the total weight. The same input and weights therefore
 * always return the same key, with distribution proportional to the
 * provided weights. Keys with a weight of 0 are never selected.
 *
 * @param input - String to hash, e.g. a routing key.
 * @param weights - Candidate keys mapped to finite, non-negative weights.
 * @returns The selected key.
 * @throws {Error} If a weight is negative or not finite, or if no key has a
 * positive weight (including an empty `weights` object).
 *
 * @example
 * const region = await selectWeightedKey('my-unique-id', {
 *   use1: 40,
 *   use2: 10,
 *   usw2: 10,
 *   euc1: 10,
 * })
 * // Returns one of the keys based on the input and weights
 */
export async function selectWeightedKey<T extends string>(
  input: string,
  weights: Record<T, number>
): Promise<T> {
  const keys = Object.keys(weights) as T[]

  // Validate up front: a negative or NaN weight would silently corrupt the
  // cumulative walk below
  let totalWeight = 0
  for (const key of keys) {
    const weight = weights[key]
    if (!Number.isFinite(weight) || weight < 0) {
      throw new Error(`selectWeightedKey: invalid weight for key "${key}": ${weight}`)
    }
    totalWeight += weight
  }

  // Without this guard `hashInt % 0` is NaN and no key can be chosen
  if (totalWeight <= 0) {
    throw new Error('selectWeightedKey: at least one key must have a positive weight')
  }

  const digest = await crypto.subtle.digest('SHA-256', new TextEncoder().encode(input))

  // Use first 4 bytes (32 bit integer)
  const hashInt = new DataView(digest).getUint32(0)
  const targetWeight = hashInt % totalWeight

  // Walk the keys in insertion order and pick the first one whose cumulative
  // weight exceeds the target
  let cumulativeWeight = 0
  for (const key of keys) {
    cumulativeWeight += weights[key]
    if (cumulativeWeight > targetWeight) {
      return key
    }
  }

  // Not reachable: the cumulative sum ends at totalWeight, which is strictly
  // greater than targetWeight
  throw new Error('selectWeightedKey: failed to select a key')
}

0 commit comments

Comments
 (0)