[Feature] Added MCTSPolicyBase, MCTSPolicy, AlphaGoPolicy, AlphaStarPolicy, and MuZeroPolicy #3449
ParamThakkar123 wants to merge 19 commits into pytorch:main
This reverts commit 1f6f327.
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3449
Note: Links to docs will display an error until the docs builds have been completed.
❌ 5 New Failures, 1 Cancelled Job
As of commit 77868ed with merge base ab49b59:
NEW FAILURES - The following jobs have failed:
CANCELLED JOB - The following job was cancelled. Please retry:
This comment was automatically generated by Dr. CI and updates every 15 minutes.

| Prefix | Label Applied | Example |
|---|---|---|
| [BugFix] | BugFix | [BugFix] Fix memory leak in collector |
| [Feature] | Feature | [Feature] Add new optimizer |
| [Doc] or [Docs] | Documentation | [Doc] Update installation guide |
| [Refactor] | Refactoring | [Refactor] Clean up module imports |
| [CI] | CI | [CI] Fix workflow permissions |
| [Test] or [Tests] | Tests | [Tests] Add unit tests for buffer |
| [Environment] or [Environments] | Environments | [Environments] Add Gymnasium support |
| [Data] | Data | [Data] Fix replay buffer sampling |
| [Performance] or [Perf] | Performance | [Performance] Optimize tensor ops |
| [BC-Breaking] | bc breaking | [BC-Breaking] Remove deprecated API |
| [Deprecation] | Deprecation | [Deprecation] Mark old function |

Note: Common variations like singular/plural are supported (e.g., [Doc] or [Docs]).

1 test is currently failing. Working on fixing that.

All tests passed locally.

class AlphaStarPolicy(AlphaGoPolicy):
    """AlphaStar-style MCTS policy with a lower exploration constant.

    This policy is similar to AlphaGo but uses a smaller exploration constant (c=1.0) for potentially
    more exploitative behavior.

    Args:
        c (float, optional): Exploration constant. Defaults to 1.0.
        **kwargs: Additional keyword arguments passed to AlphaGoPolicy.
    """

    def __init__(self, *, c: float = 1.0, **kwargs) -> None:
        super().__init__(c=c, **kwargs)


class MuZeroPolicy(AlphaGoPolicy):
    """MuZero-style MCTS policy with a specific exploration constant.

    This policy implements the selection from MuZero, using PUCT with c=1.25.

    Args:
        c (float, optional): Exploration constant. Defaults to 1.25.
        **kwargs: Additional keyword arguments passed to AlphaGoPolicy.
    """

    def __init__(self, *, c: float = 1.25, **kwargs) -> None:
        super().__init__(c=c, **kwargs)

This is a 3-class hierarchy to set a single float.
Let's consider whether class methods or constants would be more appropriate (e.g., MCTSPolicy.from_alphago(), MCTSPolicy.from_muzero())?
Additionally, the naming is somewhat misleading -- AlphaStar and MuZero use substantially different MCTS variants beyond just exploration constant tuning (e.g., MuZero uses a learned model for tree expansion). I guess users might expect these classes to encapsulate those algorithmic differences.
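
One way to read the suggestion above is as named constructors on a single class. The following is a minimal sketch under stated assumptions, not the PR's implementation: `PUCTSelector` and its `scores`/`select` methods are hypothetical stand-ins, the scoring rule is the commonly used AlphaZero-style PUCT formula (with the parent count approximated by the sum of child visits), and only the constants 1.0 and 1.25 come from the diff. A `from_alphago()` preset would follow the same pattern, but its default value is not shown in this excerpt.

```python
import math
from dataclasses import dataclass
from typing import List, Sequence


@dataclass(frozen=True)
class PUCTSelector:
    """Hypothetical stand-in for a single PUCT-based selection policy."""

    c: float  # exploration constant

    @classmethod
    def from_alphastar(cls) -> "PUCTSelector":
        # c=1.0 is the AlphaStarPolicy default in the diff under review.
        return cls(c=1.0)

    @classmethod
    def from_muzero(cls) -> "PUCTSelector":
        # c=1.25 is the MuZeroPolicy default in the diff under review.
        return cls(c=1.25)

    def scores(
        self,
        q: Sequence[float],
        priors: Sequence[float],
        visits: Sequence[int],
    ) -> List[float]:
        # Assumed rule: Q + c * P * sqrt(N_parent) / (1 + N_child),
        # with N_parent approximated by the sum of child visit counts.
        total = sum(visits)
        return [
            qi + self.c * pi * math.sqrt(total) / (1 + ni)
            for qi, pi, ni in zip(q, priors, visits)
        ]

    def select(
        self,
        q: Sequence[float],
        priors: Sequence[float],
        visits: Sequence[int],
    ) -> int:
        # Pick the index of the highest-scoring action.
        s = self.scores(q, priors, visits)
        return max(range(len(s)), key=s.__getitem__)


selector = PUCTSelector.from_muzero()
# With equal values and priors, the unvisited action gets the larger bonus.
action = selector.select(q=[0.0, 0.0], priors=[0.5, 0.5], visits=[0, 3])
```

With this shape the presets differ only in data, so no subclass is needed. It also keeps the names from implying algorithmic differences (such as a learned model for expansion) that the class does not implement.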
Oh, got it @vmoens. I will try to refine these implementations that way.
Only 3 test methods for the entire new feature:
- No batched input tests -- the policy classes explicitly handle broadcasting and multi-dim masks, but none of this is tested.
- No edge case tests -- e.g., all-False mask (should raise ValueError per the code), different score modules (UCB, EXP3).
- No test for MCTSPolicyBase abstractness -- verify it can't be instantiated directly.
- No test for custom key overrides on the specialized policies.
The existing test changes (torch.tensor() removal from assert_close calls) are an unrelated cleanup and should be a separate PR.
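
Two of the missing cases listed above (abstractness of the base class and the all-False mask) could be tested along these lines. This is only a sketch of the test shape: `SelectorBase` and `GreedySelector` are hypothetical stand-ins written in plain Python so the example runs on its own, and they are not the PR's `MCTSPolicyBase` or `MCTSPolicy`. In the real test suite, `pytest.raises` would replace the manual `try`/`except` blocks, and the batched and custom-key cases would need tensor inputs that are not shown here.

```python
from abc import ABC, abstractmethod
from typing import Sequence


class SelectorBase(ABC):
    """Hypothetical stand-in for an abstract policy base class."""

    @abstractmethod
    def select(self, scores: Sequence[float], mask: Sequence[bool]) -> int:
        ...


class GreedySelector(SelectorBase):
    """Hypothetical concrete policy: argmax over legal actions."""

    def select(self, scores: Sequence[float], mask: Sequence[bool]) -> int:
        legal = [i for i, ok in enumerate(mask) if ok]
        if not legal:
            # Mirrors the behavior the review says the code should have.
            raise ValueError("action mask has no legal actions")
        return max(legal, key=lambda i: scores[i])


def test_base_is_abstract() -> None:
    try:
        SelectorBase()  # abstract classes raise TypeError on instantiation
    except TypeError:
        return
    raise AssertionError("SelectorBase should not be instantiable")


def test_all_false_mask_raises() -> None:
    try:
        GreedySelector().select([0.1, 0.9], [False, False])
    except ValueError:
        return
    raise AssertionError("an all-False mask should raise ValueError")


def test_mask_excludes_illegal_best_action() -> None:
    # Index 0 has the best score but is masked out, so index 2 wins.
    assert GreedySelector().select([0.9, 0.1, 0.5], [False, True, True]) == 2
```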
from tensordict.nn import TensorDictModuleBase


class _ScoreFactory:

Yeah, this one is related to the scores function PR, which gave some errors when executed with the other policies.
Description
Describe your changes in detail.
Added MCTSPolicyBase, MCTSPolicy, AlphaGoPolicy, AlphaStarPolicy, and MuZeroPolicy to mcts policies.
Motivation and Context
Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax close #15213 if this solves the issue #15213.
This PR doesn't close an issue but is a subtask of #2357.
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an x in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!