Commit 90dda53

ooctipus and kellyguo11 authored
Enhances Pbt usage experience through small improvements (#3449)
# Description

This PR incorporates feedback from PBT users and makes the following improvements:

1. Added resume logic to allow wandb to continue on the same run_id
2. Corrected the broadcasting order in the distributed setup
3. Made the score query general by using dotted keys to access dictionaries of arbitrary depth

Fixes # (issue)

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have read and understood the [contribution guidelines](https://isaac-sim.github.io/IsaacLab/main/source/refs/contributing.html)
- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [x] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already exists there

---------

Co-authored-by: Kelly Guo <[email protected]>
1 parent 187f9a5 commit 90dda53

5 files changed, +33 -12 lines changed

docs/source/features/population_based_training.rst

Lines changed: 4 additions & 4 deletions
@@ -49,7 +49,7 @@ Example Config
 num_policies: 8
 directory: .
 workspace: "pbt_workspace"
-objective: Curriculum/difficulty_level
+objective: episode.Curriculum/difficulty_level
 interval_steps: 50000000
 threshold_std: 0.1
 threshold_abs: 0.025
@@ -66,9 +66,9 @@ Example Config
 agent.params.config.tau: "mutate_discount"
 
 
-``objective: Curriculum/difficulty_level`` uses ``infos["episode"]["Curriculum/difficulty_level"]`` as the scalar to
-**rank policies** (higher is better). With ``num_policies: 8``, launch eight processes sharing the same ``workspace``
-and unique ``policy_idx`` (0-7).
+``objective: episode.Curriculum/difficulty_level`` is the dotted expression that uses
+``infos["episode"]["Curriculum/difficulty_level"]`` as the scalar to **rank policies** (higher is better).
+With ``num_policies: 8``, launch eight processes sharing the same ``workspace`` and unique ``policy_idx`` (0-7).
 
 
 Launching PBT
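
The dotted `objective` can be read as a path into the nested `infos` dictionary. The following is a minimal sketch of that lookup, mirroring the loop this commit adds to `process_infos`; the helper name `resolve_dotted` and the sample values are hypothetical and not part of the repository.

```python
def resolve_dotted(infos, dotted_key):
    """Walk a nested dict by following a dot-separated path of keys."""
    value = infos
    for part in dotted_key.split("."):
        value = value[part]  # raises KeyError if any segment is missing
    return value


# Hypothetical infos payload; only the nesting matters here.
infos = {"episode": {"Curriculum/difficulty_level": 3.5}}

score = resolve_dotted(infos, "episode.Curriculum/difficulty_level")
print(score)
```

Because the path is split on `.`, a key that itself contains a dot cannot be addressed this way, while `/` inside a key (as in `Curriculum/difficulty_level`) is unaffected.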

scripts/reinforcement_learning/rl_games/train.py

Lines changed: 3 additions & 2 deletions
@@ -226,8 +226,9 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
             monitor_gym=True,
             save_code=True,
         )
-        wandb.config.update({"env_cfg": env_cfg.to_dict()})
-        wandb.config.update({"agent_cfg": agent_cfg})
+        if not wandb.run.resumed:
+            wandb.config.update({"env_cfg": env_cfg.to_dict()})
+            wandb.config.update({"agent_cfg": agent_cfg})
 
     if args_cli.checkpoint is not None:
         runner.run({"train": True, "play": False, "sigma": train_sigma, "checkpoint": resume_path})
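
The guard above writes the configuration only when the run is fresh and leaves a resumed run's stored configuration alone. A minimal sketch of that pattern follows, assuming nothing beyond a `resumed` flag and a dict-like `config`; `log_configs_once` and the `SimpleNamespace` stand-ins for `wandb.run` are hypothetical, so the snippet runs without wandb installed.

```python
from types import SimpleNamespace


def log_configs_once(run, configs):
    """Update the run config only for a fresh run; skip it for a resumed run."""
    if run.resumed:
        return False
    for name, cfg in configs.items():
        run.config.update({name: cfg})
    return True


# Hypothetical stand-ins for `wandb.run`; a plain dict plays the role of `run.config`.
fresh = SimpleNamespace(resumed=False, config={})
resumed = SimpleNamespace(resumed=True, config={"env_cfg": {"seed": 1}})

wrote_fresh = log_configs_once(fresh, {"env_cfg": {"seed": 42}})      # config is written
wrote_resumed = log_configs_once(resumed, {"env_cfg": {"seed": 42}})  # config is left as stored
```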

source/isaaclab_rl/config/extension.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 [package]
 
 # Note: Semantic Versioning is used: https://semver.org/
-version = "0.4.0"
+version = "0.4.1"
 
 # Description
 title = "Isaac Lab RL"

source/isaaclab_rl/docs/CHANGELOG.rst

Lines changed: 12 additions & 0 deletions
@@ -1,6 +1,18 @@
 Changelog
 ---------
 
+0.4.1 (2025-09-09)
+~~~~~~~~~~~~~~~~~~
+
+Fixed
+^^^^^
+
+* Made PBT a bit nicer by
+* 1. added resume logic to allow wandb to continue on the same run_id
+* 2. corrected broadcasting order in distributed setup
+* 3. made score query general by using dotted keys to access dictionary of arbitrary depth
+
+
 0.4.0 (2025-09-09)
 ~~~~~~~~~~~~~~~~~~
 

source/isaaclab_rl/isaaclab_rl/rl_games/pbt/pbt.py

Lines changed: 13 additions & 5 deletions
@@ -68,9 +68,12 @@ def process_infos(self, infos, done_indices):
         """Extract the scalar objective from environment infos and store in `self.score`.
 
         Notes:
-            Expects the objective to be at `infos["episode"][self.cfg.objective]`.
+            Expects the objective to be at `infos[self.cfg.objective]` where self.cfg.objective is dotted address.
         """
-        self.score = infos["episode"][self.cfg.objective]
+        score = infos
+        for part in self.cfg.objective.split("."):
+            score = score[part]
+        self.score = score
 
     def after_steps(self):
         """Main PBT tick executed every train step.
@@ -84,6 +87,9 @@ def after_steps(self):
         whitelisted params, set `restart_flag`, broadcast (if distributed),
         and print a mutation diff table.
         """
+        if self.distributed_args.distributed:
+            dist.broadcast(self.restart_flag, src=0)
+
         if self.distributed_args.rank != 0:
             if self.restart_flag.cpu().item() == 1:
                 os._exit(0)
@@ -154,9 +160,6 @@ def after_steps(self):
         self.new_params = mutate(cur_params, self.cfg.mutation, self.cfg.mutation_rate, self.cfg.change_range)
         self.restart_from_checkpoint = os.path.abspath(ckpts[replacement_policy_candidate]["checkpoint"])
         self.restart_flag[0] = 1
-        if self.distributed_args.distributed:
-            dist.broadcast(self.restart_flag, src=0)
-
         self.printer.print_mutation_diff(cur_params, self.new_params)
 
     def _restart_with_new_params(self, new_params, restart_from_checkpoint):
@@ -191,6 +194,11 @@ def _restart_with_new_params(self, new_params, restart_from_checkpoint):
         if self.wandb_args.enabled:
             import wandb
 
+            # note setdefault will only affect child process, that mean don't have to worry it env variable
+            # propagate beyond restarted child process
+            os.environ.setdefault("WANDB_RUN_ID", wandb.run.id)  # continue with the same run id
+            os.environ.setdefault("WANDB_RESUME", "allow")  # allow wandb to resume
+            os.environ.setdefault("WANDB_INIT_TIMEOUT", "300")  # give wandb init more time to be fault tolerant
             wandb.run.finish()
 
         # Get the directory of the current file
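
The resume logic in the last hunk relies on two standard-library behaviours: `os.environ.setdefault` only fills in a variable that is not already set, and a child process started without an explicit `env` inherits the parent's environment. The sketch below illustrates both using only the standard library; the run id `abc123` and the pre-existing `WANDB_RESUME` value are hypothetical, and no wandb call is made.

```python
import os
import subprocess
import sys

# Hypothetical starting state: no run id exported, but the user already set WANDB_RESUME.
os.environ.pop("WANDB_RUN_ID", None)
os.environ["WANDB_RESUME"] = "must"

# In the commit the id comes from `wandb.run.id`; "abc123" is a placeholder.
os.environ.setdefault("WANDB_RUN_ID", "abc123")  # unset, so it is filled in
os.environ.setdefault("WANDB_RESUME", "allow")   # already set, so "must" is kept

# The child inherits the parent's environment because no `env` argument is passed.
child_code = "import os; print(os.environ['WANDB_RUN_ID'], os.environ['WANDB_RESUME'])"
child_output = subprocess.run(
    [sys.executable, "-c", child_code], capture_output=True, text=True, check=True
).stdout.strip()
print(child_output)  # prints "abc123 must"
```

Using `setdefault` rather than direct assignment means a value the user exported before launching takes precedence over the one set at restart time.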
