-
Notifications
You must be signed in to change notification settings - Fork 297
Expand file tree
/
Copy pathairl.py
More file actions
105 lines (84 loc) · 3.29 KB
/
airl.py
File metadata and controls
105 lines (84 loc) · 3.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Config and run configuration for AIRL."""
import dataclasses
import logging
import pathlib
from typing import Any, Dict, Sequence, cast
import hydra
import torch as th
from hydra.core.config_store import ConfigStore
from hydra.utils import instantiate
from omegaconf import MISSING
from imitation.policies import serialize
from imitation_cli.algorithm_configurations import airl as airl_cfg
from imitation_cli.utils import environment as environment_cfg
from imitation_cli.utils import (
policy_evaluation,
randomness,
reward_network,
rl_algorithm,
trajectories,
)
@dataclasses.dataclass
class RunConfig:
    """Config for running AIRL."""

    # Seeding config for all stochastic components (default seed 0).
    rng: randomness.Config = randomness.Config(seed=0)
    # Total number of environment timesteps to train for.
    total_timesteps: int = int(1e6)
    # Save a checkpoint every N rounds; 0 saves only the final checkpoint,
    # and a negative value disables saving entirely (see run_airl).
    checkpoint_interval: int = 0
    # MISSING fields must be provided via the hydra config / CLI overrides.
    environment: environment_cfg.Config = MISSING
    airl: airl_cfg.Config = MISSING
    evaluation: policy_evaluation.Config = MISSING
    # This ensures that the working directory is changed
    # to the hydra output dir
    hydra: Any = dataclasses.field(default_factory=lambda: dict(job=dict(chdir=True)))
# Register all structured configs with Hydra's global ConfigStore so they can
# be selected/overridden from the CLI. The "${...}" strings are OmegaConf
# interpolations resolved at compose time (e.g. every group shares the same
# environment and rng nodes of the run config).
cs = ConfigStore.instance()
environment_cfg.register_configs("environment", "${rng}")
trajectories.register_configs("airl/demonstrations", "${environment}", "${rng}")
rl_algorithm.register_configs("airl/gen_algo", "${environment}", "${rng.seed}")
reward_network.register_configs("airl/reward_net", "${environment}")
policy_evaluation.register_configs("evaluation", "${environment}", "${rng}")

# Base run config: wires the shared environment into the AIRL algorithm config.
cs.store(
    name="airl_run_base",
    node=RunConfig(
        airl=airl_cfg.Config(
            venv="${environment}",  # type: ignore[arg-type]
        ),
    ),
)
@hydra.main(
    version_base=None,
    config_path="config",
    config_name="airl_run",
)
def run_airl(cfg: RunConfig) -> Dict[str, Any]:
    """Train an AIRL agent per ``cfg``, checkpointing and evaluating it.

    Args:
        cfg: Composed hydra run configuration (see ``RunConfig``).

    Returns:
        Dict with ``imit_stats`` (evaluation stats of the trained policy) and
        ``expert_stats`` (rollout stats of the expert demonstrations).
    """
    # Deferred imports keep CLI startup (e.g. --help) fast and avoid pulling
    # in heavy dependencies before hydra has composed the config.
    from imitation.algorithms.adversarial import airl
    from imitation.data import rollout
    from imitation.data.types import TrajectoryWithRew

    trainer: airl.AIRL = instantiate(cfg.airl)
    # Relative path: hydra chdir=True (see RunConfig.hydra) puts us in the
    # run's output directory, so checkpoints land alongside other artifacts.
    checkpoints_path = pathlib.Path("checkpoints")

    def save(path: str) -> None:
        """Save discriminator and generator under ``checkpoints/<path>``."""
        # We implement this here and not in Trainer since we do not want to actually
        # serialize the whole Trainer (including e.g. expert demonstrations).
        save_path = checkpoints_path / path
        save_path.mkdir(parents=True, exist_ok=True)
        th.save(trainer.reward_train, save_path / "reward_train.pt")
        th.save(trainer.reward_test, save_path / "reward_test.pt")
        serialize.save_stable_model(save_path / "gen_policy", trainer.gen_algo)

    def callback(round_num: int, /) -> None:
        # Periodic checkpointing; interval <= 0 disables intermediate saves.
        if cfg.checkpoint_interval > 0 and round_num % cfg.checkpoint_interval == 0:
            # Lazy %-style args: formatting only happens if INFO is enabled.
            logging.info("Saving checkpoint at round %d", round_num)
            save(f"{round_num:05d}")

    trainer.train(cfg.total_timesteps, callback)
    imit_stats = policy_evaluation.eval_policy(trainer.policy, cfg.evaluation)

    # Save final artifacts (skipped when checkpoint_interval is negative).
    if cfg.checkpoint_interval >= 0:
        logging.info("Saving final checkpoint.")
        save("final")

    return {
        "imit_stats": imit_stats,
        "expert_stats": rollout.rollout_stats(
            cast(Sequence[TrajectoryWithRew], trainer.get_demonstrations()),
        ),
    }
# Script entry point: hydra parses CLI overrides and invokes run_airl.
if __name__ == "__main__":
    run_airl()