
Commit 4c136e6

Commit message: fixes
1 parent c846df6 commit 4c136e6

File tree

4 files changed: +312 -38 lines changed

test/test_mcts.py

Lines changed: 241 additions & 6 deletions
@@ -8,7 +8,8 @@
 import pytest
 import torch
 from tensordict import TensorDict
-from torchrl.modules.mcts.scores import EXP3Score, PUCTScore, UCBScore
+from torchrl.modules.mcts.scores import EXP3Score, PUCTScore, UCB1TunedScore, UCBScore
+

 # Sample TensorDict for testing
 def create_node(
@@ -22,14 +23,15 @@ def create_node(
     }

     if batch_size:
+        # num_actions needs batch dimension to match TensorDict batch_size
         data = {
-            custom_keys["num_actions_key"]: torch.tensor(
-                [num_actions] * batch_size, device=device
+            custom_keys["num_actions_key"]: torch.full(
+                (batch_size,), num_actions, device=device, dtype=torch.long
             )
         }
         if weights is not None:
             if weights.ndim == 1:
-                weights = weights.unsqueeze(0).repeat(batch_size, 1)
+                weights = weights.unsqueeze(0).expand(batch_size, -1)
             data[custom_keys["weights_key"]] = weights.to(device)
         td = TensorDict(data, batch_size=[batch_size], device=device)
     else:
@@ -81,7 +83,7 @@ def create_ucb_node(
     }
     td = TensorDict(
         data,
-        batch_size=[batch_s for batch_s in batch_size]
+        batch_size=list(batch_size)
         if isinstance(batch_size, (list, tuple))
         else [batch_size],
         device=device,
@@ -352,7 +354,7 @@ def test_update_weights_batch(self, default_scorer, batch_s):
             action_idx = action_indices[i].item()
             reward = rewards[i].item()

-            single_node_td = node[i]
+            node[i]

             current_weight_item = initial_weights_batch[i, action_idx]
             prob_i_item = probs_batch[i, action_idx]
@@ -845,3 +847,236 @@ def test_custom_keys(self, puct_custom_key_names):
         assert "visits" not in node.keys()
         assert "total_visits" not in node.keys()
         assert "prior_prob" not in node.keys()
+
+
+# Helper function to create a sample TensorDict node for UCB1TunedScore
+def create_ucb1_tuned_node(
+    win_count,
+    visits,
+    total_visits,
+    sum_squared_rewards,
+    batch_size=None,
+    device="cpu",
+    custom_keys=None,
+):
+    if custom_keys is None:
+        custom_keys = {
+            "win_count_key": "win_count",
+            "visits_key": "visits",
+            "total_visits_key": "total_visits",
+            "sum_squared_rewards_key": "sum_squared_rewards",
+            "score_key": "score",
+        }
+
+    win_count = torch.as_tensor(win_count, device=device, dtype=torch.float32)
+    visits = torch.as_tensor(visits, device=device, dtype=torch.float32)
+    total_visits = torch.as_tensor(total_visits, device=device, dtype=torch.float32)
+    sum_squared_rewards = torch.as_tensor(
+        sum_squared_rewards, device=device, dtype=torch.float32
+    )
+
+    if batch_size:
+        if win_count.ndim == 0:
+            win_count = win_count.unsqueeze(0).repeat(batch_size)
+        elif win_count.shape[0] != batch_size:
+            raise ValueError("Batch size mismatch for win_count")
+        if visits.ndim == 0:
+            visits = visits.unsqueeze(0).repeat(batch_size)
+        elif visits.shape[0] != batch_size:
+            raise ValueError("Batch size mismatch for visits")
+        if sum_squared_rewards.ndim == 0:
+            sum_squared_rewards = sum_squared_rewards.unsqueeze(0).repeat(batch_size)
+        elif sum_squared_rewards.shape[0] != batch_size:
+            raise ValueError("Batch size mismatch for sum_squared_rewards")
+        if total_visits.numel() == 1 and batch_size > 1:
+            total_visits = total_visits.repeat(batch_size)
+        elif total_visits.ndim == 0:
+            total_visits = total_visits.unsqueeze(0).repeat(batch_size)
+        elif total_visits.shape[0] != batch_size:
+            raise ValueError("Batch size mismatch for total_visits")
+
+        data = {
+            custom_keys["win_count_key"]: win_count,
+            custom_keys["visits_key"]: visits,
+            custom_keys["total_visits_key"]: total_visits,
+            custom_keys["sum_squared_rewards_key"]: sum_squared_rewards,
+        }
+        if isinstance(batch_size, (list, tuple)):
+            td_batch_size = batch_size
+        else:
+            td_batch_size = [batch_size]
+        td = TensorDict(data, batch_size=td_batch_size, device=device)
+    else:
+        data = {
+            custom_keys["win_count_key"]: win_count,
+            custom_keys["visits_key"]: visits,
+            custom_keys["total_visits_key"]: total_visits,
+            custom_keys["sum_squared_rewards_key"]: sum_squared_rewards,
+        }
+        td_batch_size = win_count.shape[:-1] if win_count.ndim > 1 else []
+        td = TensorDict(data, batch_size=td_batch_size, device=device)
+
+    return td
+
+
+class TestUCB1TunedScore:
+    @pytest.fixture
+    def default_ucb1_tuned_scorer(self):
+        return UCB1TunedScore(exploration_constant=2.0)
+
+    @pytest.fixture
+    def ucb1_tuned_custom_key_names(self):
+        return {
+            "win_count_key": "custom_wins",
+            "visits_key": "custom_visits",
+            "total_visits_key": "custom_total_visits",
+            "sum_squared_rewards_key": "custom_sum_sq_rewards",
+            "score_key": "custom_ucb1_tuned_score",
+        }
+
+    @pytest.mark.parametrize("exploration_constant", [1.0, 2.0, 3.0])
+    def test_initialization(self, exploration_constant):
+        scorer = UCB1TunedScore(exploration_constant=exploration_constant)
+        assert scorer.exploration_constant == exploration_constant
+
+    def test_forward_basic(self, default_ucb1_tuned_scorer):
+        # Rewards in [0, 1] range for UCB1-Tuned
+        win_count = torch.tensor([0.8, 0.6, 0.9])  # sum of rewards
+        visits = torch.tensor([10.0, 5.0, 15.0])
+        total_visits = torch.tensor(30.0)
+        # sum_squared_rewards for rewards in [0,1]
+        sum_squared_rewards = torch.tensor([0.7, 0.4, 0.85])
+
+        node = create_ucb1_tuned_node(
+            win_count=win_count,
+            visits=visits,
+            total_visits=total_visits,
+            sum_squared_rewards=sum_squared_rewards,
+        )
+        default_ucb1_tuned_scorer.forward(node)
+
+        scores = node.get(default_ucb1_tuned_scorer.score_key)
+        assert scores.shape == win_count.shape
+        # All visited actions should have finite scores
+        assert torch.all(torch.isfinite(scores))
+
+    def test_forward_unvisited_actions(self, default_ucb1_tuned_scorer):
+        win_count = torch.tensor([0.5, 0.0, 0.3])
+        visits = torch.tensor([5.0, 0.0, 3.0])  # Second action unvisited
+        total_visits = torch.tensor(8.0)
+        sum_squared_rewards = torch.tensor([0.3, 0.0, 0.15])
+
+        node = create_ucb1_tuned_node(
+            win_count=win_count,
+            visits=visits,
+            total_visits=total_visits,
+            sum_squared_rewards=sum_squared_rewards,
+        )
+        default_ucb1_tuned_scorer.forward(node)
+
+        scores = node.get(default_ucb1_tuned_scorer.score_key)
+        # Unvisited action should have a very large score
+        assert scores[1] > scores[0]
+        assert scores[1] > scores[2]
+        # Should be close to max float / 10
+        assert scores[1] > 1e30
+
+    @pytest.mark.parametrize("batch_s", [2, 3])
+    def test_forward_batch(self, default_ucb1_tuned_scorer, batch_s):
+        num_actions = 3
+        win_count = torch.rand(batch_s, num_actions)
+        visits = torch.rand(batch_s, num_actions) * 5 + 1  # Ensure visits > 0
+        total_visits = torch.rand(batch_s) * 20 + float(batch_s)
+        sum_squared_rewards = torch.rand(batch_s, num_actions)
+
+        node = create_ucb1_tuned_node(
+            win_count=win_count,
+            visits=visits,
+            total_visits=total_visits,
+            sum_squared_rewards=sum_squared_rewards,
+            batch_size=batch_s,
+        )
+        default_ucb1_tuned_scorer.forward(node)
+
+        scores = node.get(default_ucb1_tuned_scorer.score_key)
+        assert scores.shape == (batch_s, num_actions)
+        # All should be finite since all visits > 0
+        assert torch.all(torch.isfinite(scores))
+
+    def test_forward_variance_clamping(self, default_ucb1_tuned_scorer):
+        # Test that min(0.25, V_i) is applied correctly
+        # High variance case
+        win_count = torch.tensor([5.0])
+        visits = torch.tensor([10.0])
+        total_visits = torch.tensor(100.0)
+        # Very high sum of squared rewards to create high variance
+        sum_squared_rewards = torch.tensor([10.0])
+
+        node = create_ucb1_tuned_node(
+            win_count=win_count,
+            visits=visits,
+            total_visits=total_visits,
+            sum_squared_rewards=sum_squared_rewards,
+        )
+        default_ucb1_tuned_scorer.forward(node)
+
+        scores = node.get(default_ucb1_tuned_scorer.score_key)
+        assert torch.all(torch.isfinite(scores))
+
+    def test_custom_keys(self, ucb1_tuned_custom_key_names):
+        scorer = UCB1TunedScore(
+            exploration_constant=2.0,
+            win_count_key=ucb1_tuned_custom_key_names["win_count_key"],
+            visits_key=ucb1_tuned_custom_key_names["visits_key"],
+            total_visits_key=ucb1_tuned_custom_key_names["total_visits_key"],
+            sum_squared_rewards_key=ucb1_tuned_custom_key_names[
+                "sum_squared_rewards_key"
+            ],
+            score_key=ucb1_tuned_custom_key_names["score_key"],
+        )
+
+        win_count = torch.tensor([0.5, 0.3])
+        visits = torch.tensor([5.0, 3.0])
+        total_visits = torch.tensor(10.0)
+        sum_squared_rewards = torch.tensor([0.3, 0.15])
+
+        node = create_ucb1_tuned_node(
+            win_count=win_count,
+            visits=visits,
+            total_visits=total_visits,
+            sum_squared_rewards=sum_squared_rewards,
+            custom_keys=ucb1_tuned_custom_key_names,
+        )
+        scorer.forward(node)
+
+        assert ucb1_tuned_custom_key_names["score_key"] in node.keys()
+        scores = node.get(ucb1_tuned_custom_key_names["score_key"])
+        assert scores.shape == win_count.shape
+        assert torch.all(torch.isfinite(scores))
+
+        # Check that default keys are not present
+        assert "score" not in node.keys()
+        assert "win_count" not in node.keys()
+        assert "visits" not in node.keys()
+        assert "total_visits" not in node.keys()
+        assert "sum_squared_rewards" not in node.keys()
+
+    def test_exploration_vs_exploitation(self, default_ucb1_tuned_scorer):
+        # Action 0: high average reward, many visits (exploitation)
+        # Action 1: low average reward, few visits (exploration)
+        win_count = torch.tensor([9.0, 1.0])
+        visits = torch.tensor([10.0, 2.0])
+        total_visits = torch.tensor(12.0)
+        sum_squared_rewards = torch.tensor([8.5, 0.6])
+
+        node = create_ucb1_tuned_node(
+            win_count=win_count,
+            visits=visits,
+            total_visits=total_visits,
+            sum_squared_rewards=sum_squared_rewards,
+        )
+        default_ucb1_tuned_scorer.forward(node)
+
+        scores = node.get(default_ucb1_tuned_scorer.score_key)
+        # Both should be finite
+        assert torch.all(torch.isfinite(scores))
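The new tests pin down UCB1TunedScore's observable behaviour — finite scores for visited actions, a near-max-float sentinel for unvisited ones, and the min(0.25, V_i) variance cap — without spelling the formula out. For orientation only, below is a minimal reference sketch of the classic UCB1-Tuned rule (Auer et al., 2002) that these checks appear to exercise; the function name, the placement of exploration_constant, and the unvisited-action sentinel are illustrative assumptions, not TorchRL's actual implementation.

import torch


def ucb1_tuned_reference(
    win_count, visits, total_visits, sum_squared_rewards, exploration_constant=2.0
):
    # Reference-only sketch of UCB1-Tuned (Auer et al., 2002), not TorchRL's code.
    safe_visits = visits.clamp(min=1.0)
    mean = win_count / safe_visits                       # empirical mean reward per action
    log_total = torch.log(total_visits.clamp(min=1.0))   # ln N, broadcast over actions
    # Per-action variance bound: V_i = (1/n_i) * sum x^2 - mean^2 + sqrt(2 ln N / n_i)
    variance_bound = (
        sum_squared_rewards / safe_visits
        - mean**2
        + torch.sqrt(2.0 * log_total / safe_visits)
    )
    # UCB1-Tuned caps the variance term at 1/4, the maximum variance of a [0, 1] reward.
    exploration = torch.sqrt(
        exploration_constant
        * log_total
        / safe_visits
        * torch.minimum(torch.tensor(0.25), variance_bound)
    )
    scores = mean + exploration
    # Give unvisited actions a huge score so they are tried first
    # (the tests expect roughly max float / 10).
    sentinel = torch.finfo(scores.dtype).max / 10
    return torch.where(visits > 0, scores, torch.full_like(scores, sentinel))


# Inputs mirroring test_forward_basic
win_count = torch.tensor([0.8, 0.6, 0.9])
visits = torch.tensor([10.0, 5.0, 15.0])
total_visits = torch.tensor(30.0)
sum_squared_rewards = torch.tensor([0.7, 0.4, 0.85])
print(ucb1_tuned_reference(win_count, visits, total_visits, sum_squared_rewards))

On the inputs from test_forward_basic this yields finite scores for all three actions, which is exactly what that test asserts.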

torchrl/modules/__init__.py

Lines changed: 14 additions & 0 deletions
@@ -95,6 +95,14 @@
 from .tensordict_module.exploration import RandomPolicy
 from .utils import get_primers_from_module
 from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner  # usort:skip
+from .mcts import (  # usort:skip
+    EXP3Score,
+    MCTSScore,
+    MCTSScores,
+    PUCTScore,
+    UCB1TunedScore,
+    UCBScore,
+)

 __all__ = [
     "Actor",
@@ -105,6 +113,7 @@
     "AdditiveGaussianWrapper",
     "BatchRenorm1d",
     "CEMPlanner",
+    "EXP3Score",
     "ConsistentDropout",
     "ConsistentDropoutModule",
     "Conv3dNet",
@@ -134,6 +143,8 @@
     "LSTM",
     "LSTMCell",
     "LSTMModule",
+    "MCTSScore",
+    "MCTSScores",
     "MLP",
     "MPCPlannerBase",
     "MPPIPlanner",
@@ -157,6 +168,7 @@
     "OrnsteinUhlenbeckProcessModule",
     "OrnsteinUhlenbeckProcessWrapper",
     "ProbabilisticActor",
+    "PUCTScore",
     "QMixer",
     "QValueActor",
     "QValueHook",
@@ -175,6 +187,8 @@
     "TanhModule",
     "TanhNormal",
     "TruncatedNormal",
+    "UCB1TunedScore",
+    "UCBScore",
     "VDNMixer",
     "ValueOperator",
     "VmapModule",

torchrl/modules/mcts/__init__.py

Lines changed: 17 additions & 1 deletion
@@ -2,4 +2,20 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .scores import EXP3Score, PUCTScore, UCBScore
+from .scores import (
+    EXP3Score,
+    MCTSScore,
+    MCTSScores,
+    PUCTScore,
+    UCB1TunedScore,
+    UCBScore,
+)
+
+__all__ = [
+    "EXP3Score",
+    "MCTSScore",
+    "MCTSScores",
+    "PUCTScore",
+    "UCB1TunedScore",
+    "UCBScore",
+]
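With these exports in place, UCB1TunedScore can be imported from either torchrl.modules or torchrl.modules.mcts. A short usage sketch stitched together from the tests above follows; the default key names and the in-place forward() behaviour come from test_mcts.py, and anything beyond that should be read as an assumption rather than documented API.

import torch
from tensordict import TensorDict

from torchrl.modules.mcts import UCB1TunedScore

# A single node carrying the default keys exercised by the tests.
node = TensorDict(
    {
        "win_count": torch.tensor([0.8, 0.6, 0.9]),  # sum of rewards per action
        "visits": torch.tensor([10.0, 5.0, 15.0]),   # per-action visit counts
        "total_visits": torch.tensor(30.0),          # parent visit count
        "sum_squared_rewards": torch.tensor([0.7, 0.4, 0.85]),
    },
    batch_size=[],
)

scorer = UCB1TunedScore(exploration_constant=2.0)
scorer.forward(node)  # writes the per-action scores into the node in place

scores = node.get(scorer.score_key)  # score_key defaults to "score" in the tests
best_action = scores.argmax()        # greedy selection over the UCB1-Tuned scores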
