@@ -425,4 +425,94 @@ describe('BanditEvaluator', () => {
425
425
expect ( resultB . optimalityGap ) . toBeCloseTo ( 0.3 ) ;
426
426
} ) ;
427
427
} ) ;
428
+
429
+ describe ( 'evaluateBestBandit' , ( ) => {
430
+ it ( 'evaluates the bandit with action contexts' , ( ) => {
431
+ const subjectAttributes : ContextAttributes = {
432
+ numericAttributes : { age : 25 } ,
433
+ categoricalAttributes : { location : 'US' } ,
434
+ } ;
435
+ const subjectAttributesB : ContextAttributes = {
436
+ numericAttributes : { age : 25 } ,
437
+ categoricalAttributes : { } ,
438
+ } ;
439
+ const actions : Record < string , ContextAttributes > = {
440
+ action1 : { numericAttributes : { price : 10 } , categoricalAttributes : { category : 'A' } } ,
441
+ action2 : { numericAttributes : { price : 20 } , categoricalAttributes : { category : 'B' } } ,
442
+ } ;
443
+ const banditModel : BanditModelData = {
444
+ gamma : 0.1 ,
445
+ defaultActionScore : 0.0 ,
446
+ actionProbabilityFloor : 0.1 ,
447
+ coefficients : {
448
+ action1 : {
449
+ actionKey : 'action1' ,
450
+ intercept : 0.5 ,
451
+ subjectNumericCoefficients : [
452
+ { attributeKey : 'age' , coefficient : 0.1 , missingValueCoefficient : 0.0 } ,
453
+ ] ,
454
+ subjectCategoricalCoefficients : [
455
+ {
456
+ attributeKey : 'location' ,
457
+ missingValueCoefficient : 0.0 ,
458
+ valueCoefficients : { US : 0.2 } ,
459
+ } ,
460
+ ] ,
461
+ actionNumericCoefficients : [
462
+ { attributeKey : 'price' , coefficient : 0.05 , missingValueCoefficient : 0.0 } ,
463
+ ] ,
464
+ actionCategoricalCoefficients : [
465
+ {
466
+ attributeKey : 'category' ,
467
+ missingValueCoefficient : 0.0 ,
468
+ valueCoefficients : { A : 0.3 } ,
469
+ } ,
470
+ ] ,
471
+ } ,
472
+ action2 : {
473
+ actionKey : 'action2' ,
474
+ intercept : 0.3 ,
475
+ subjectNumericCoefficients : [
476
+ { attributeKey : 'age' , coefficient : 0.1 , missingValueCoefficient : 0.0 } ,
477
+ ] ,
478
+ subjectCategoricalCoefficients : [
479
+ {
480
+ attributeKey : 'location' ,
481
+ missingValueCoefficient : - 3.0 ,
482
+ valueCoefficients : { US : 0.2 } ,
483
+ } ,
484
+ ] ,
485
+ actionNumericCoefficients : [
486
+ { attributeKey : 'price' , coefficient : 0.05 , missingValueCoefficient : 0.0 } ,
487
+ ] ,
488
+ actionCategoricalCoefficients : [
489
+ {
490
+ attributeKey : 'category' ,
491
+ missingValueCoefficient : 0.0 ,
492
+ valueCoefficients : { B : 0.3 } ,
493
+ } ,
494
+ ] ,
495
+ } ,
496
+ } ,
497
+ } ;
498
+
499
+ // Subject A gets assigned action 2
500
+ const resultA = banditEvaluator . evaluateBestBanditAction (
501
+ subjectAttributes ,
502
+ actions ,
503
+ banditModel ,
504
+ ) ;
505
+
506
+ expect ( resultA ) . toEqual ( 'action2' ) ;
507
+
508
+ // Subject B gets assigned action 1 because of the missing location penalty
509
+ const resultB = banditEvaluator . evaluateBestBanditAction (
510
+ subjectAttributesB ,
511
+ actions ,
512
+ banditModel ,
513
+ ) ;
514
+
515
+ expect ( resultB ) . toEqual ( 'action1' ) ;
516
+ } ) ;
517
+ } ) ;
428
518
} ) ;
0 commit comments