Skip to content

Commit ebac8c5

Browse files
authored
Merge pull request #1260 from guardian/tf-sample-count
Make bandit sample count configurable
2 parents 50bae30 + 9e93bdd commit ebac8c5

File tree

2 files changed

+26
-7
lines changed

2 files changed

+26
-7
lines changed

src/server/bandit/banditData.ts

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
import { isProd } from '../lib/env';
22
import * as AWS from 'aws-sdk';
33
import { buildReloader, ValueProvider } from '../utils/valueReloader';
4-
import { BannerTest, Channel, EpicTest, Methodology, Test, Variant } from '../../shared/types';
4+
import {
5+
BanditMethodology,
6+
BannerTest,
7+
Channel,
8+
EpicTest,
9+
Test,
10+
Variant,
11+
} from '../../shared/types';
512
import { z } from 'zod';
613
import { logError } from '../utils/logging';
714
import { putMetric } from '../utils/cloudwatch';
@@ -33,6 +40,7 @@ interface BanditTestConfig {
3340
testName: string; // this may be specific to the methodology, e.g. MY_TEST_EpsilonGreedyBandit-0.5
3441
channel: Channel;
3542
variantNames: string[];
43+
sampleCount?: number;
3644
}
3745

3846
// If sampleCount is not provided, all samples will be returned
@@ -52,8 +60,12 @@ function queryForTestSamples(testName: string, channel: Channel, sampleCount?: n
5260
.promise();
5361
}
5462

55-
async function getBanditSamplesForTest(testName: string, channel: Channel): Promise<TestSample[]> {
56-
const queryResult = await queryForTestSamples(testName, channel);
63+
async function getBanditSamplesForTest(
64+
testName: string,
65+
channel: Channel,
66+
sampleCount?: number,
67+
): Promise<TestSample[]> {
68+
const queryResult = await queryForTestSamples(testName, channel, sampleCount);
5769

5870
const parsedResults = queryResultSchema.safeParse(queryResult.Items);
5971

@@ -134,7 +146,7 @@ async function buildBanditDataForTest(test: BanditTestConfig): Promise<BanditDat
134146
};
135147
}
136148

137-
const samples = await getBanditSamplesForTest(test.testName, test.channel);
149+
const samples = await getBanditSamplesForTest(test.testName, test.channel, test.sampleCount);
138150

139151
if (samples.length < MINIMUM_SAMPLES) {
140152
return getDefaultWeighting(test);
@@ -152,13 +164,15 @@ async function buildBanditDataForTest(test: BanditTestConfig): Promise<BanditDat
152164

153165
// Return config for each bandit methodology in this test
154166
function getBanditTestConfigs<V extends Variant, T extends Test<V>>(test: T): BanditTestConfig[] {
155-
const bandits: Methodology[] = (test.methodologies ?? []).filter(
167+
const bandits: BanditMethodology[] = (test.methodologies ?? []).filter(
156168
(method) => method.name === 'EpsilonGreedyBandit' || method.name === 'Roulette',
157-
);
169+
) as BanditMethodology[];
170+
158171
return bandits.map((method) => ({
159172
testName: method.testName ?? test.name, // if the methodology should be tracked with a different name then use that
160173
channel: test.channel,
161174
variantNames: test.variants.map((v) => v.name),
175+
sampleCount: method.sampleCount,
162176
}));
163177
}
164178

src/shared/types/abTests/shared.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,12 @@ const abTestMethodologySchema = z.object({ name: z.literal('ABTest') });
4343
const epsilonGreedyMethodologySchema = z.object({
4444
name: z.literal('EpsilonGreedyBandit'),
4545
epsilon: z.number(),
46+
sampleCount: z.number().optional(),
47+
});
48+
const rouletteMethodologySchema = z.object({
49+
name: z.literal('Roulette'),
50+
sampleCount: z.number().optional(),
4651
});
47-
const rouletteMethodologySchema = z.object({ name: z.literal('Roulette') });
4852
const methodologySchema = z.intersection(
4953
z.discriminatedUnion('name', [
5054
abTestMethodologySchema,
@@ -55,6 +59,7 @@ const methodologySchema = z.intersection(
5559
z.object({ testName: z.string().optional() }),
5660
);
5761
export type Methodology = z.infer<typeof methodologySchema>;
62+
export type BanditMethodology = Exclude<Methodology, { name: 'ABTest' }>;
5863

5964
export interface Variant {
6065
name: string;

0 commit comments

Comments
 (0)