Skip to content

Commit 2d5db11

Browse files
Update Hugging Face 🤗 Push To Hub (#379)
* Update push_to_hub.py
* Only render when needed
* Patch Atari game video recording
* Patch Atari rendering
* Remove Atari patch, will be fixed by SB3 update
* Fix record video steps + update comments
* Update versions

---------

Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
1 parent 382f4b4 commit 2d5db11

File tree

6 files changed

+34
-19
lines changed

6 files changed

+34
-19
lines changed

CHANGELOG.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
1-
## Release 2.0.0a9 (WIP)
1+
## Release 2.0.0a12 (WIP)
22

33
### Breaking Changes
44
- Upgraded to gym 0.26+
55
- Fixed bug in HistoryWrapper, now returns the correct obs space limits
66
- Upgraded to SB3 >= 2.0.0
7+
- Upgraded to Huggingface-SB3 >= 2.2.5
78

89
### New Features
910
- Gym 0.26+ patches to continue working with pybullet and TimeLimit wrapper
1011

1112
### Bug fixes
12-
- Renamed ``CarRacing-v1`` to ``CarRacing-v2`` in hyperparameters
13+
- Renamed `CarRacing-v1` to `CarRacing-v2` in hyperparameters
14+
- Huggingface push to hub now accepts a `--n-timesteps` argument to adjust the length of the video
15+
- Fixed `record_video` steps (before it was stepping in a closed env)
1316

1417
## Release 1.8.0 (2023-04-07)
1518

requirements.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
gym==0.26.2
2-
stable-baselines3[extra_no_roms,tests,docs]>=2.0.0a9
3-
sb3-contrib>=2.0.0a9
2+
stable-baselines3[extra_no_roms,tests,docs]>=2.0.0a13
3+
sb3-contrib>=2.0.0a13
44
box2d-py==2.3.8
55
pybullet
66
# minigrid
@@ -14,7 +14,7 @@ plotly
1414
# panda-gym~=3.0.1
1515
rliable>=1.0.5
1616
wandb
17-
huggingface_sb3>=2.2.1
17+
huggingface_sb3>=2.2.5
1818
seaborn
1919
tqdm
2020
rich

rl_zoo3/push_to_hub.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
from wasabi import Printer
2121

2222
import rl_zoo3.import_envs # noqa: F401 pylint: disable=unused-import
23-
from rl_zoo3 import ALGOS, create_test_env, get_saved_hyperparams
23+
from rl_zoo3 import ALGOS, get_saved_hyperparams
2424
from rl_zoo3.exp_manager import ExperimentManager
25-
from rl_zoo3.utils import StoreDict, get_model_path
25+
from rl_zoo3.utils import StoreDict, create_test_env, get_model_path
2626

2727
msg = Printer()
2828

@@ -277,12 +277,12 @@ def package_to_hub(
277277

278278
if __name__ == "__main__":
279279
parser = argparse.ArgumentParser()
280-
parser.add_argument("--env", help="environment ID", type=EnvironmentName, required=True)
280+
parser.add_argument("--env", help="Environment ID", type=EnvironmentName, required=True)
281281
parser.add_argument("-f", "--folder", help="Log folder", type=str, required=True)
282282
parser.add_argument("--algo", help="RL Algorithm", type=str, required=True, choices=list(ALGOS.keys()))
283-
parser.add_argument("-n", "--n-timesteps", help="number of timesteps", default=1000, type=int)
283+
parser.add_argument("-n", "--n-timesteps", help="Number of timesteps for the video recording", default=1000, type=int)
284284
parser.add_argument("--num-threads", help="Number of threads for PyTorch (-1 to use default)", default=-1, type=int)
285-
parser.add_argument("--n-envs", help="number of environments", default=1, type=int)
285+
parser.add_argument("--n-envs", help="Number of environments", default=1, type=int)
286286
parser.add_argument("--exp-id", help="Experiment ID (default: 0: latest, -1: no exp folder)", default=0, type=int)
287287
parser.add_argument("--verbose", help="Verbose mode (0: no output, 1: INFO)", default=1, type=int)
288288
parser.add_argument(
@@ -357,6 +357,12 @@ def package_to_hub(
357357
loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader) # pytype: disable=module-attr
358358
if loaded_args["env_kwargs"] is not None:
359359
env_kwargs = loaded_args["env_kwargs"]
360+
361+
# render and record video by default
362+
should_render = not args.no_render
363+
if should_render:
364+
env_kwargs.update(render_mode="rgb_array")
365+
360366
# overwrite with command line arguments
361367
if args.env_kwargs is not None:
362368
env_kwargs.update(args.env_kwargs)
@@ -367,7 +373,7 @@ def package_to_hub(
367373
stats_path=maybe_stats_path,
368374
seed=args.seed,
369375
log_dir=None,
370-
should_render=not args.no_render,
376+
should_render=should_render,
371377
hyperparams=deepcopy(hyperparams),
372378
env_kwargs=env_kwargs,
373379
)
@@ -377,6 +383,12 @@ def package_to_hub(
377383
# Dummy buffer size as we don't need memory to enjoy the trained agent
378384
kwargs.update(dict(buffer_size=1))
379385

386+
# Hack due to breaking change in v1.6
387+
# handle_timeout_termination cannot be at the same time
388+
# with optimize_memory_usage
389+
if "optimize_memory_usage" in hyperparams:
390+
kwargs.update(optimize_memory_usage=False)
391+
380392
# Note: we assume that we push models using the same machine (same python version)
381393
# that trained them, if not, we would need to pass custom object as in enjoy.py
382394
custom_objects: Dict[str, Any] = {}
@@ -411,6 +423,6 @@ def package_to_hub(
411423
n_eval_episodes=10,
412424
token=None,
413425
local_repo_path="hub",
414-
video_length=1000,
426+
video_length=args.n_timesteps,
415427
generate_video=not args.no_render,
416428
)

rl_zoo3/record_video.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313

1414
if __name__ == "__main__":
1515
parser = argparse.ArgumentParser()
16-
parser.add_argument("--env", help="environment ID", type=EnvironmentName, default="CartPole-v1")
16+
parser.add_argument("--env", help="Environment ID", type=EnvironmentName, default="CartPole-v1")
1717
parser.add_argument("-f", "--folder", help="Log folder", type=str, default="rl-trained-agents")
1818
parser.add_argument("-o", "--output-folder", help="Output folder", type=str)
1919
parser.add_argument("--algo", help="RL Algorithm", default="ppo", type=str, required=False, choices=list(ALGOS.keys()))
20-
parser.add_argument("-n", "--n-timesteps", help="number of timesteps", default=1000, type=int)
21-
parser.add_argument("--n-envs", help="number of environments", default=1, type=int)
20+
parser.add_argument("-n", "--n-timesteps", help="Number of timesteps", default=1000, type=int)
21+
parser.add_argument("--n-envs", help="Number of environments", default=1, type=int)
2222
parser.add_argument("--deterministic", action="store_true", default=False, help="Use deterministic actions")
2323
parser.add_argument("--stochastic", action="store_true", default=False, help="Use stochastic actions")
2424
parser.add_argument("--seed", help="Random generator seed", type=int, default=0)
@@ -150,7 +150,7 @@
150150
lstm_states = None
151151
episode_starts = np.ones((env.num_envs,), dtype=bool)
152152
try:
153-
for _ in range(video_length + 1):
153+
for _ in range(video_length):
154154
action, lstm_states = model.predict(
155155
obs, # type: ignore[arg-type]
156156
state=lstm_states,

rl_zoo3/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.0.0a9
1+
2.0.0a13

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
},
2828
entry_points={"console_scripts": ["rl_zoo3=rl_zoo3.cli:main"]},
2929
install_requires=[
30-
"sb3_contrib>=2.0.0a9",
30+
"sb3_contrib>=2.0.0a13",
3131
"gym==0.26.2",
32-
"huggingface_sb3>=2.2.1",
32+
"huggingface_sb3>=2.2.5",
3333
"tqdm",
3434
"rich",
3535
"optuna",

0 commit comments

Comments
 (0)