Skip to content

Commit a7cccdc

Browse files
committed
minimum weight for roulette sampling
1 parent 24a2719 commit a7cccdc

File tree

2 files changed

+56
-4
lines changed

2 files changed

+56
-4
lines changed

src/server/roulette/rouletteSelection.test.ts

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,48 @@ describe('roulette', () => {
140140
const variant = selectVariantUsingRoulette([banditData], epicTest, rand);
141141
expect(variant).toBeDefined();
142142
});
143+
144+
it('should ensure a minimum of 10% for variants with mean of 0', () => {
145+
const variants = [
146+
{
147+
variantName: 'v1',
148+
mean: 2,
149+
},
150+
{
151+
variantName: 'v2',
152+
mean: 0,
153+
},
154+
{
155+
variantName: 'v3',
156+
mean: 0,
157+
},
158+
];
159+
const banditData = {
160+
testName: 'example-1',
161+
bestVariants: variants,
162+
variants: variants,
163+
};
164+
165+
/**
166+
* variantsWithWeights: [
167+
* { variantName: 'v2', weight: 0.1 },
168+
* { variantName: 'v3', weight: 0.1 },
169+
* { variantName: 'v1', weight: 1 }
170+
* ]
171+
*
172+
* normalisedWeights: [
173+
* { variantName: 'v2', weight: 0.08333333333333334 },
174+
* { variantName: 'v3', weight: 0.08333333333333334 },
175+
* { variantName: 'v1', weight: 0.8333333333333334 }
176+
* ]
177+
*/
178+
const variantSelection1 = selectVariantUsingRoulette([banditData], epicTest, 0.08);
179+
const variantSelection2 = selectVariantUsingRoulette([banditData], epicTest, 0.16);
180+
const variantSelection3 = selectVariantUsingRoulette([banditData], epicTest, 0.2);
181+
expect(variantSelection1).toBe(epicTest.variants[1]);
182+
expect(variantSelection2).toBe(epicTest.variants[2]);
183+
expect(variantSelection3).toBe(epicTest.variants[0]);
184+
});
143185
});
144186

145187
describe('rouletteTest2', () => {

src/server/roulette/rouletteSelection.ts

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,23 @@ export function selectVariantUsingRoulette<V extends Variant, T extends Test<V>>
1818
return selectRandomVariant(test);
1919
}
2020

21-
// sorted variant weights, which will add up to 1
21+
const minWeight = 0.1; // Ensure no variant gets less than 10%
2222
const variantsWithWeights: { weight: number; variantName: string }[] = testBanditData.variants
23-
.map(({ variantName, mean }) => ({ variantName, weight: mean / sumOfMeans }))
23+
.map(({ variantName, mean }) => ({
24+
variantName,
25+
weight: Math.max(mean / sumOfMeans, minWeight),
26+
}))
2427
.sort((a, b) => a.weight - b.weight);
2528

26-
for (let i = 0, acc = 0; i < variantsWithWeights.length; i++) {
27-
const variant = variantsWithWeights[i];
29+
// The sum of the weights may be greater than 1, so we now need to normalise them
30+
const sumOfWeights = variantsWithWeights.reduce((sum, v) => sum + v.weight, 0);
31+
const normalisedWeights = variantsWithWeights.map(({ variantName, weight }) => ({
32+
variantName,
33+
weight: weight / sumOfWeights,
34+
}));
35+
36+
for (let i = 0, acc = 0; i < normalisedWeights.length; i++) {
37+
const variant = normalisedWeights[i];
2838
if (rand < variant.weight + acc) {
2939
return test.variants.find((v) => v.name === variant.variantName);
3040
}

0 commit comments

Comments
 (0)