-
Notifications
You must be signed in to change notification settings - Fork 433
[Feature] Added EXP3 Scoring function in continuation with pr #2358 #3013
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3013
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 1 Cancelled Job, 1 Unrelated FailureAs of commit 7344bcb with merge base 9d9f6cb ( NEW FAILURE - The following job has failed:
CANCELLED JOB - The following job was cancelled. Please retry:
BROKEN TRUNK - The following job failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
vmoens
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG
Just need tests, docstrings and add it to the docs (see docs/ directory where you'll need to manually add the classes where they fit, I can help if it's unclear).
|
@vmoens I branched out from your PR branch. I will be adding more docstrings, tests and also the other two methods which are yet to be implemented |
|
Yes @vmoens , all the tests that I have added are related to all the three scoring functions |
|
Can you look at #6b57d53? Here's a detailed breakdown of what I changed: 1.
|
| File | Issue | Fix |
|---|---|---|
scores.py |
Unused nn import |
Removed |
scores.py |
Missing warnings import |
Added |
scores.py |
MCTSScore missing type hints |
Added type annotation and docstring |
scores.py |
PUCTScore.forward batch broadcasting |
Added n_total.unsqueeze(-1) when needed |
scores.py |
UCBScore.forward batch broadcasting |
Added n_total.unsqueeze(-1) when needed |
scores.py |
EXP3Score.forward batched num_actions |
Handle batched tensors with flatten()[0] |
scores.py |
Error message format | Fixed comma/space consistency |
scores.py |
update_weights exceptions vs warnings |
Changed to warnings.warn + fixed dead code |
scores.py |
UCB1TunedScore clamp not saved |
Changed to v_i_v = v_i_v.clamp(min=0) |
scores.py |
PUCT_VARIANT placeholder |
Removed |
modules/__init__.py |
Missing module exports | Added mcts imports and __all__ entries |
test_mcts.py |
create_node batch handling |
Use torch.full() for proper batch dims |
test_mcts.py |
Missing UCB1TunedScore tests | Added full test class |
6b57d53 to
4c136e6
Compare
4c136e6 to
a3ca7ae
Compare
Description
Added EXP3 Scoring function in continuation with pr #2358
Motivation and Context
Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax
close #15213if this solves the issue #15213Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an
xin all the boxes that apply.If you are unsure about any of these, don't hesitate to ask. We are here to help!