11import { isProd } from '../lib/env' ;
22import * as AWS from 'aws-sdk' ;
33import { buildReloader , ValueProvider } from '../utils/valueReloader' ;
4- import { BannerTest , Channel , EpicTest , Methodology , Test , Variant } from '../../shared/types' ;
4+ import {
5+ BanditMethodology ,
6+ BannerTest ,
7+ Channel ,
8+ EpicTest ,
9+ Test ,
10+ Variant ,
11+ } from '../../shared/types' ;
512import { z } from 'zod' ;
613import { logError } from '../utils/logging' ;
714import { putMetric } from '../utils/cloudwatch' ;
@@ -33,6 +40,7 @@ interface BanditTestConfig {
3340 testName : string ; // this may be specific to the methodology, e.g. MY_TEST_EpsilonGreedyBandit-0.5
3441 channel : Channel ;
3542 variantNames : string [ ] ;
43+ sampleCount ?: number ;
3644}
3745
3846// If sampleCount is not provided, all samples will be returned
@@ -52,8 +60,12 @@ function queryForTestSamples(testName: string, channel: Channel, sampleCount?: n
5260 . promise ( ) ;
5361}
5462
55- async function getBanditSamplesForTest ( testName : string , channel : Channel ) : Promise < TestSample [ ] > {
56- const queryResult = await queryForTestSamples ( testName , channel ) ;
63+ async function getBanditSamplesForTest (
64+ testName : string ,
65+ channel : Channel ,
66+ sampleCount ?: number ,
67+ ) : Promise < TestSample [ ] > {
68+ const queryResult = await queryForTestSamples ( testName , channel , sampleCount ) ;
5769
5870 const parsedResults = queryResultSchema . safeParse ( queryResult . Items ) ;
5971
@@ -134,7 +146,7 @@ async function buildBanditDataForTest(test: BanditTestConfig): Promise<BanditDat
134146 } ;
135147 }
136148
137- const samples = await getBanditSamplesForTest ( test . testName , test . channel ) ;
149+ const samples = await getBanditSamplesForTest ( test . testName , test . channel , test . sampleCount ) ;
138150
139151 if ( samples . length < MINIMUM_SAMPLES ) {
140152 return getDefaultWeighting ( test ) ;
@@ -152,13 +164,15 @@ async function buildBanditDataForTest(test: BanditTestConfig): Promise<BanditDat
152164
153165// Return config for each bandit methodology in this test
154166function getBanditTestConfigs < V extends Variant , T extends Test < V > > ( test : T ) : BanditTestConfig [ ] {
155- const bandits : Methodology [ ] = ( test . methodologies ?? [ ] ) . filter (
167+ const bandits : BanditMethodology [ ] = ( test . methodologies ?? [ ] ) . filter (
156168 ( method ) => method . name === 'EpsilonGreedyBandit' || method . name === 'Roulette' ,
157- ) ;
169+ ) as BanditMethodology [ ] ;
170+
158171 return bandits . map ( ( method ) => ( {
159172 testName : method . testName ?? test . name , // if the methodology should be tracked with a different name then use that
160173 channel : test . channel ,
161174 variantNames : test . variants . map ( ( v ) => v . name ) ,
175+ sampleCount : method . sampleCount ,
162176 } ) ) ;
163177}
164178
0 commit comments