@@ -425,4 +425,94 @@ describe('BanditEvaluator', () => {
425425 expect ( resultB . optimalityGap ) . toBeCloseTo ( 0.3 ) ;
426426 } ) ;
427427 } ) ;
428+
429+ describe ( 'evaluateBestBandit' , ( ) => {
430+ it ( 'evaluates the bandit with action contexts' , ( ) => {
431+ const subjectAttributes : ContextAttributes = {
432+ numericAttributes : { age : 25 } ,
433+ categoricalAttributes : { location : 'US' } ,
434+ } ;
435+ const subjectAttributesB : ContextAttributes = {
436+ numericAttributes : { age : 25 } ,
437+ categoricalAttributes : { } ,
438+ } ;
439+ const actions : Record < string , ContextAttributes > = {
440+ action1 : { numericAttributes : { price : 10 } , categoricalAttributes : { category : 'A' } } ,
441+ action2 : { numericAttributes : { price : 20 } , categoricalAttributes : { category : 'B' } } ,
442+ } ;
443+ const banditModel : BanditModelData = {
444+ gamma : 0.1 ,
445+ defaultActionScore : 0.0 ,
446+ actionProbabilityFloor : 0.1 ,
447+ coefficients : {
448+ action1 : {
449+ actionKey : 'action1' ,
450+ intercept : 0.5 ,
451+ subjectNumericCoefficients : [
452+ { attributeKey : 'age' , coefficient : 0.1 , missingValueCoefficient : 0.0 } ,
453+ ] ,
454+ subjectCategoricalCoefficients : [
455+ {
456+ attributeKey : 'location' ,
457+ missingValueCoefficient : 0.0 ,
458+ valueCoefficients : { US : 0.2 } ,
459+ } ,
460+ ] ,
461+ actionNumericCoefficients : [
462+ { attributeKey : 'price' , coefficient : 0.05 , missingValueCoefficient : 0.0 } ,
463+ ] ,
464+ actionCategoricalCoefficients : [
465+ {
466+ attributeKey : 'category' ,
467+ missingValueCoefficient : 0.0 ,
468+ valueCoefficients : { A : 0.3 } ,
469+ } ,
470+ ] ,
471+ } ,
472+ action2 : {
473+ actionKey : 'action2' ,
474+ intercept : 0.3 ,
475+ subjectNumericCoefficients : [
476+ { attributeKey : 'age' , coefficient : 0.1 , missingValueCoefficient : 0.0 } ,
477+ ] ,
478+ subjectCategoricalCoefficients : [
479+ {
480+ attributeKey : 'location' ,
481+ missingValueCoefficient : - 3.0 ,
482+ valueCoefficients : { US : 0.2 } ,
483+ } ,
484+ ] ,
485+ actionNumericCoefficients : [
486+ { attributeKey : 'price' , coefficient : 0.05 , missingValueCoefficient : 0.0 } ,
487+ ] ,
488+ actionCategoricalCoefficients : [
489+ {
490+ attributeKey : 'category' ,
491+ missingValueCoefficient : 0.0 ,
492+ valueCoefficients : { B : 0.3 } ,
493+ } ,
494+ ] ,
495+ } ,
496+ } ,
497+ } ;
498+
499+ // Subject A gets assigned action 2
500+ const resultA = banditEvaluator . evaluateBestBanditAction (
501+ subjectAttributes ,
502+ actions ,
503+ banditModel ,
504+ ) ;
505+
506+ expect ( resultA ) . toEqual ( 'action2' ) ;
507+
508+ // Subject B gets assigned action 1 because of the missing location penalty
509+ const resultB = banditEvaluator . evaluateBestBanditAction (
510+ subjectAttributesB ,
511+ actions ,
512+ banditModel ,
513+ ) ;
514+
515+ expect ( resultB ) . toEqual ( 'action1' ) ;
516+ } ) ;
517+ } ) ;
428518} ) ;
0 commit comments