88import pytest
99import torch
1010from tensordict import TensorDict
11- from torchrl .modules .mcts .scores import EXP3Score , PUCTScore , UCBScore
11+ from torchrl .modules .mcts .scores import EXP3Score , PUCTScore , UCB1TunedScore , UCBScore
12+
1213
1314# Sample TensorDict for testing
1415def create_node (
@@ -22,14 +23,15 @@ def create_node(
2223 }
2324
2425 if batch_size :
26+ # num_actions needs batch dimension to match TensorDict batch_size
2527 data = {
26- custom_keys ["num_actions_key" ]: torch .tensor (
27- [ num_actions ] * batch_size , device = device
28+ custom_keys ["num_actions_key" ]: torch .full (
29+ ( batch_size ,), num_actions , device = device , dtype = torch . long
2830 )
2931 }
3032 if weights is not None :
3133 if weights .ndim == 1 :
32- weights = weights .unsqueeze (0 ).repeat (batch_size , 1 )
34+ weights = weights .unsqueeze (0 ).expand (batch_size , - 1 )
3335 data [custom_keys ["weights_key" ]] = weights .to (device )
3436 td = TensorDict (data , batch_size = [batch_size ], device = device )
3537 else :
@@ -81,7 +83,7 @@ def create_ucb_node(
8183 }
8284 td = TensorDict (
8385 data ,
84- batch_size = [ batch_s for batch_s in batch_size ]
86+ batch_size = list ( batch_size )
8587 if isinstance (batch_size , (list , tuple ))
8688 else [batch_size ],
8789 device = device ,
@@ -352,7 +354,7 @@ def test_update_weights_batch(self, default_scorer, batch_s):
352354 action_idx = action_indices [i ].item ()
353355 reward = rewards [i ].item ()
354356
355- single_node_td = node [i ]
357+ node [i ]
356358
357359 current_weight_item = initial_weights_batch [i , action_idx ]
358360 prob_i_item = probs_batch [i , action_idx ]
@@ -845,3 +847,236 @@ def test_custom_keys(self, puct_custom_key_names):
845847 assert "visits" not in node .keys ()
846848 assert "total_visits" not in node .keys ()
847849 assert "prior_prob" not in node .keys ()
850+
851+
852+ # Helper function to create a sample TensorDict node for UCB1TunedScore
def create_ucb1_tuned_node(
    win_count,
    visits,
    total_visits,
    sum_squared_rewards,
    batch_size=None,
    device="cpu",
    custom_keys=None,
):
    """Build a TensorDict of node statistics for the UCB1TunedScore tests.

    Every statistic is converted to a ``float32`` tensor on ``device`` and
    stored under the key names given by ``custom_keys``.

    Args:
        win_count: Value stored under ``custom_keys["win_count_key"]``.
        visits: Value stored under ``custom_keys["visits_key"]``.
        total_visits: Value stored under ``custom_keys["total_visits_key"]``.
        sum_squared_rewards: Value stored under
            ``custom_keys["sum_squared_rewards_key"]``.
        batch_size: When truthy, 0-dim statistics are repeated ``batch_size``
            times and the leading dimension of every other statistic must
            equal ``batch_size``. When falsy, the TensorDict batch size is
            taken from ``win_count``'s leading dimensions (all but the last),
            or is empty if ``win_count`` has at most one dimension.
        device: Device for the tensors and for the TensorDict.
        custom_keys: Mapping from ``"win_count_key"``, ``"visits_key"``,
            ``"total_visits_key"`` and ``"sum_squared_rewards_key"`` to the
            entry names to use. Defaults to the plain names. The default
            mapping also carries ``"score_key"``, which this helper never
            reads.

    Returns:
        The TensorDict holding the four statistics.

    Raises:
        ValueError: If ``batch_size`` is truthy and a non-scalar statistic's
            leading dimension differs from it.
    """
    if custom_keys is None:
        custom_keys = {
            "win_count_key": "win_count",
            "visits_key": "visits",
            "total_visits_key": "total_visits",
            "sum_squared_rewards_key": "sum_squared_rewards",
            "score_key": "score",
        }

    def _as_float(value):
        # Single place for the dtype/device conversion applied to every input.
        return torch.as_tensor(value, device=device, dtype=torch.float32)

    def _match_batch(stat, mismatch_message):
        # Scalars are tiled along a new leading batch dimension; anything else
        # must already have ``batch_size`` as its leading dimension.
        if stat.ndim == 0:
            return stat.unsqueeze(0).repeat(batch_size)
        if stat.shape[0] != batch_size:
            raise ValueError(mismatch_message)
        return stat

    win_count = _as_float(win_count)
    visits = _as_float(visits)
    total_visits = _as_float(total_visits)
    sum_squared_rewards = _as_float(sum_squared_rewards)

    if batch_size:
        win_count = _match_batch(win_count, "Batch size mismatch for win_count")
        visits = _match_batch(visits, "Batch size mismatch for visits")
        sum_squared_rewards = _match_batch(
            sum_squared_rewards, "Batch size mismatch for sum_squared_rewards"
        )
        # A single-element total is tiled directly when there is more than one
        # batch entry; every other case follows the common rule.
        if total_visits.numel() == 1 and batch_size > 1:
            total_visits = total_visits.repeat(batch_size)
        else:
            total_visits = _match_batch(
                total_visits, "Batch size mismatch for total_visits"
            )
        td_batch_size = (
            batch_size if isinstance(batch_size, (list, tuple)) else [batch_size]
        )
    else:
        td_batch_size = win_count.shape[:-1] if win_count.ndim > 1 else []

    data = {
        custom_keys["win_count_key"]: win_count,
        custom_keys["visits_key"]: visits,
        custom_keys["total_visits_key"]: total_visits,
        custom_keys["sum_squared_rewards_key"]: sum_squared_rewards,
    }
    return TensorDict(data, batch_size=td_batch_size, device=device)
920+
921+
class TestUCB1TunedScore:
    """Tests for ``UCB1TunedScore``: construction, forward pass and key names."""

    @pytest.fixture
    def default_ucb1_tuned_scorer(self):
        """Scorer built with ``exploration_constant=2.0`` and default keys."""
        return UCB1TunedScore(exploration_constant=2.0)

    @pytest.fixture
    def ucb1_tuned_custom_key_names(self):
        """Non-default entry names, keyed by the scorer's keyword names."""
        return {
            "win_count_key": "custom_wins",
            "visits_key": "custom_visits",
            "total_visits_key": "custom_total_visits",
            "sum_squared_rewards_key": "custom_sum_sq_rewards",
            "score_key": "custom_ucb1_tuned_score",
        }

    @staticmethod
    def _scored_node(scorer, **node_kwargs):
        """Build a node from ``node_kwargs``, run ``scorer.forward`` on it, return it."""
        node = create_ucb1_tuned_node(**node_kwargs)
        scorer.forward(node)
        return node

    @pytest.mark.parametrize("exploration_constant", [1.0, 2.0, 3.0])
    def test_initialization(self, exploration_constant):
        """The constructor argument is exposed as ``exploration_constant``."""
        built = UCB1TunedScore(exploration_constant=exploration_constant)
        assert built.exploration_constant == exploration_constant

    def test_forward_basic(self, default_ucb1_tuned_scorer):
        """Three visited actions yield one finite score per action."""
        scorer = default_ucb1_tuned_scorer
        # Inputs are chosen as reward statistics in the [0, 1] range.
        reward_sums = torch.tensor([0.8, 0.6, 0.9])
        node = self._scored_node(
            scorer,
            win_count=reward_sums,
            visits=torch.tensor([10.0, 5.0, 15.0]),
            total_visits=torch.tensor(30.0),
            sum_squared_rewards=torch.tensor([0.7, 0.4, 0.85]),
        )

        scores = node.get(scorer.score_key)
        assert scores.shape == reward_sums.shape
        # Every action has visits > 0, so no score may be inf/nan.
        assert torch.isfinite(scores).all()

    def test_forward_unvisited_actions(self, default_ucb1_tuned_scorer):
        """An action with zero visits outranks the visited ones by a wide margin."""
        scorer = default_ucb1_tuned_scorer
        node = self._scored_node(
            scorer,
            win_count=torch.tensor([0.5, 0.0, 0.3]),
            visits=torch.tensor([5.0, 0.0, 3.0]),  # index 1 was never visited
            total_visits=torch.tensor(8.0),
            sum_squared_rewards=torch.tensor([0.3, 0.0, 0.15]),
        )

        scores = node.get(scorer.score_key)
        unvisited_score = scores[1]
        assert unvisited_score > scores[0]
        assert unvisited_score > scores[2]
        # Original author's note: expected to be near max float / 10; only the
        # looser 1e30 lower bound is actually asserted here.
        assert unvisited_score > 1e30

    @pytest.mark.parametrize("batch_s", [2, 3])
    def test_forward_batch(self, default_ucb1_tuned_scorer, batch_s):
        """Batched random inputs give finite scores of shape (batch, actions)."""
        scorer = default_ucb1_tuned_scorer
        num_actions = 3
        node = self._scored_node(
            scorer,
            win_count=torch.rand(batch_s, num_actions),
            visits=torch.rand(batch_s, num_actions) * 5 + 1,  # keeps visits >= 1
            total_visits=torch.rand(batch_s) * 20 + float(batch_s),
            sum_squared_rewards=torch.rand(batch_s, num_actions),
            batch_size=batch_s,
        )

        scores = node.get(scorer.score_key)
        assert scores.shape == (batch_s, num_actions)
        # No action is unvisited, so every score must be finite.
        assert torch.isfinite(scores).all()

    def test_forward_variance_clamping(self, default_ucb1_tuned_scorer):
        """A large sum of squared rewards still yields a finite score.

        Intended to exercise the ``min(0.25, V_i)`` term; only finiteness is
        asserted, not the clamped value itself.
        """
        scorer = default_ucb1_tuned_scorer
        node = self._scored_node(
            scorer,
            win_count=torch.tensor([5.0]),
            visits=torch.tensor([10.0]),
            total_visits=torch.tensor(100.0),
            # Deliberately large relative to visits to push the variance up.
            sum_squared_rewards=torch.tensor([10.0]),
        )

        assert torch.isfinite(node.get(scorer.score_key)).all()

    def test_custom_keys(self, ucb1_tuned_custom_key_names):
        """Custom key names are read and written; default names never appear."""
        # The fixture's keys are exactly the scorer's keyword-argument names.
        scorer = UCB1TunedScore(
            exploration_constant=2.0, **ucb1_tuned_custom_key_names
        )

        reward_sums = torch.tensor([0.5, 0.3])
        node = self._scored_node(
            scorer,
            win_count=reward_sums,
            visits=torch.tensor([5.0, 3.0]),
            total_visits=torch.tensor(10.0),
            sum_squared_rewards=torch.tensor([0.3, 0.15]),
            custom_keys=ucb1_tuned_custom_key_names,
        )

        custom_score_key = ucb1_tuned_custom_key_names["score_key"]
        assert custom_score_key in node.keys()
        scores = node.get(custom_score_key)
        assert scores.shape == reward_sums.shape
        assert torch.isfinite(scores).all()

        # None of the default entry names may have been created.
        for default_name in (
            "score",
            "win_count",
            "visits",
            "total_visits",
            "sum_squared_rewards",
        ):
            assert default_name not in node.keys()

    def test_exploration_vs_exploitation(self, default_ucb1_tuned_scorer):
        """Both a well-explored and a rarely-visited action get finite scores.

        Action 0 has a high average reward over many visits; action 1 has a
        low average over few visits. Only finiteness is asserted; the relative
        ordering of the two scores is not checked.
        """
        scorer = default_ucb1_tuned_scorer
        node = self._scored_node(
            scorer,
            win_count=torch.tensor([9.0, 1.0]),
            visits=torch.tensor([10.0, 2.0]),
            total_visits=torch.tensor(12.0),
            sum_squared_rewards=torch.tensor([8.5, 0.6]),
        )

        assert torch.isfinite(node.get(scorer.score_key)).all()
0 commit comments