Commit b452384

Transition from Models Hub to Datasets Hub for expert trajectories (#723)
* Remove load_rollouts_from_huggingface and replace it with code in demonstrations.py that loads demonstrations from HuggingFace datasets instead of HuggingFace models.
* Allow specifying the repo_id directly in the loader_kwargs of the demonstrations ingredient and pass the remaining loader_kwargs to datasets.load_dataset.
* Simplify the demonstrations ingredient configuration while making it more flexible.
* Remove the now-obsolete test.
* Rename rollout_type to type and rollout_path to path, and make the default type "generated" to match previous behavior.
* Fix the Raises: section in the documentation of get_expert_trajectories() and improve the wording of the ValueError raised when n_expert_demos is missing while generating trajectories.
* Simplify unnecessarily complex regexes used to match raised exceptions during testing.
* Fix the regex for the ValueError to reflect the updated ValueError string.
* Add an edge case to accommodate the fact that the HuggingFace Hub only has an expert for seals/Cartpole while the testdata folder only has an expert for normal Cartpole.
* Make it explicit that in some tests the rollout should be loaded locally from disk.
* Fix formatting issues in test_scripts.py.
* Rename demonstrations.type to demonstrations.source to avoid a name clash with Python's built-in type.
* Make sure to load local demonstrations in quickstart.sh.
* Fix a formatting issue.
* Ensure the README contains the same snippet as the examples.
1 parent 688e163 commit b452384

27 files changed, +188 -187 lines changed

README.md

Lines changed: 2 additions & 2 deletions
@@ -74,10 +74,10 @@ From [examples/quickstart.sh:](examples/quickstart.sh)
 python -m imitation.scripts.train_rl with pendulum environment.fast policy_evaluation.fast rl.fast fast logging.log_dir=quickstart/rl/

 # Train GAIL from demonstrations. Tensorboard logs saved in output/ (default log directory).
-python -m imitation.scripts.train_adversarial gail with pendulum environment.fast demonstrations.fast policy_evaluation.fast rl.fast fast demonstrations.rollout_path=quickstart/rl/rollouts/final.npz
+python -m imitation.scripts.train_adversarial gail with pendulum environment.fast demonstrations.fast policy_evaluation.fast rl.fast fast demonstrations.path=quickstart/rl/rollouts/final.npz demonstrations.source=local

 # Train AIRL from demonstrations. Tensorboard logs saved in output/ (default log directory).
-python -m imitation.scripts.train_adversarial airl with pendulum environment.fast demonstrations.fast policy_evaluation.fast rl.fast fast demonstrations.rollout_path=quickstart/rl/rollouts/final.npz
+python -m imitation.scripts.train_adversarial airl with pendulum environment.fast demonstrations.fast policy_evaluation.fast rl.fast fast demonstrations.path=quickstart/rl/rollouts/final.npz demonstrations.source=local
 ```

 Tips:
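
For existing launch scripts, the command-line change in this hunk (demonstrations.rollout_path becomes demonstrations.path plus demonstrations.source=local) could be applied mechanically. The sketch below is a hypothetical helper, not part of the commit; the name `migrate_cli_overrides` is invented for illustration, and it only covers the one rewrite visible in the README diff.

```python
from typing import List

OLD_PREFIX = "demonstrations.rollout_path="


def migrate_cli_overrides(args: List[str]) -> List[str]:
    """Hypothetical helper: rewrite old-style overrides to the new names."""
    migrated: List[str] = []
    for arg in args:
        if arg.startswith(OLD_PREFIX):
            path = arg[len(OLD_PREFIX):]
            migrated.append(f"demonstrations.path={path}")
            # The README diff adds an explicit source alongside the path.
            migrated.append("demonstrations.source=local")
        else:
            migrated.append(arg)
    return migrated


new_args = migrate_cli_overrides(
    ["gail", "with", "pendulum", "demonstrations.rollout_path=quickstart/rl/rollouts/final.npz"]
)
```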

benchmarking/example_airl_seals_ant_best_hp_eval.json

Lines changed: 2 additions & 1 deletion
@@ -6,7 +6,8 @@
 },
 "checkpoint_interval": 0,
 "demonstrations": {
-"rollout_type": "ppo-huggingface",
+"source": "huggingface",
+"algo_name": "ppo",
 "n_expert_demos": null
 },
 "reward": {

benchmarking/example_airl_seals_half_cheetah_best_hp_eval.json

Lines changed: 2 additions & 1 deletion
@@ -6,7 +6,8 @@
 },
 "checkpoint_interval": 0,
 "demonstrations": {
-"rollout_type": "ppo-huggingface",
+"source": "huggingface",
+"algo_name": "ppo",
 "n_expert_demos": null
 },
 "reward": {

benchmarking/example_airl_seals_hopper_best_hp_eval.json

Lines changed: 2 additions & 1 deletion
@@ -6,7 +6,8 @@
 },
 "checkpoint_interval": 0,
 "demonstrations": {
-"rollout_type": "ppo-huggingface",
+"source": "huggingface",
+"algo_name": "ppo",
 "n_expert_demos": null
 },
 "reward": {

benchmarking/example_airl_seals_swimmer_best_hp_eval.json

Lines changed: 2 additions & 1 deletion
@@ -6,7 +6,8 @@
 },
 "checkpoint_interval": 0,
 "demonstrations": {
-"rollout_type": "ppo-huggingface",
+"source": "huggingface",
+"algo_name": "ppo",
 "n_expert_demos": null
 },
 "expert": {

benchmarking/example_airl_seals_walker_best_hp_eval.json

Lines changed: 2 additions & 1 deletion
@@ -6,7 +6,8 @@
 },
 "checkpoint_interval": 0,
 "demonstrations": {
-"rollout_type": "ppo-huggingface",
+"source": "huggingface",
+"algo_name": "ppo",
 "n_expert_demos": null
 },
 "expert": {

benchmarking/example_bc_seals_ant_best_hp_eval.json

Lines changed: 2 additions & 1 deletion
@@ -20,7 +20,8 @@
 "use_offline_rollouts": false
 },
 "demonstrations": {
-"rollout_type": "ppo-huggingface",
+"source": "huggingface",
+"algo_name": "ppo",
 "n_expert_demos": null
 },
 "policy": {

benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json

Lines changed: 2 additions & 1 deletion
@@ -20,7 +20,8 @@
 "use_offline_rollouts": false
 },
 "demonstrations": {
-"rollout_type": "ppo-huggingface",
+"source": "huggingface",
+"algo_name": "ppo",
 "n_expert_demos": null
 },
 "policy": {

benchmarking/example_bc_seals_hopper_best_hp_eval.json

Lines changed: 2 additions & 1 deletion
@@ -20,7 +20,8 @@
 "use_offline_rollouts": false
 },
 "demonstrations": {
-"rollout_type": "ppo-huggingface",
+"source": "huggingface",
+"algo_name": "ppo",
 "n_expert_demos": null
 },
 "policy": {

benchmarking/example_bc_seals_swimmer_best_hp_eval.json

Lines changed: 2 additions & 1 deletion
@@ -20,7 +20,8 @@
 "use_offline_rollouts": false
 },
 "demonstrations": {
-"rollout_type": "ppo-huggingface",
+"source": "huggingface",
+"algo_name": "ppo",
 "n_expert_demos": null
 },
 "policy": {
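
The benchmarking configs above all receive the same rewrite of the "demonstrations" block: the combined "rollout_type": "ppo-huggingface" value is split into "source": "huggingface" and "algo_name": "ppo". A minimal sketch of that migration for other saved configs follows. The helper name `migrate_demonstrations_config` is hypothetical, and passing any other rollout_type value through unchanged as the source is an assumption, since the diff only shows the "-huggingface" case.

```python
from typing import Any, Dict

HF_SUFFIX = "-huggingface"


def migrate_demonstrations_config(old: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: convert an old "demonstrations" block to the new keys."""
    # Keep every key that was not renamed (for example n_expert_demos).
    new = {k: v for k, v in old.items() if k not in ("rollout_type", "rollout_path")}

    if "rollout_path" in old:
        new["path"] = old["rollout_path"]

    rollout_type = old.get("rollout_type")
    if rollout_type is not None:
        if rollout_type.endswith(HF_SUFFIX):
            # "ppo-huggingface" -> source "huggingface", algo_name "ppo"
            new["source"] = "huggingface"
            new["algo_name"] = rollout_type[: -len(HF_SUFFIX)]
        else:
            # Assumption: other values carry over unchanged as the source.
            new["source"] = rollout_type
    return new


migrated = migrate_demonstrations_config(
    {"rollout_type": "ppo-huggingface", "n_expert_demos": None}
)
```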
