Skip to content

Commit 5e2e9a6

Browse files
committed
[Algorithm] PILCO
1 parent 4d2c3cb commit 5e2e9a6

File tree

4 files changed

+819
-0
lines changed

4 files changed

+819
-0
lines changed

sota-check/run_pilco.sh

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#!/bin/bash

#SBATCH --job-name=pilco
#SBATCH --ntasks=32
#SBATCH --cpus-per-task=1
#SBATCH --gres=gpu:1
#SBATCH --output=slurm_logs/pilco_%j.txt
#SBATCH --error=slurm_errors/pilco_%j.txt

# Run the PILCO sota-check example under SLURM and append a success/error
# line for this job to report.log.

current_commit=$(git rev-parse --short HEAD)
project_name="torchrl-example-check-$current_commit"
group_name="pilco"

# Repository root is two levels above this script's working directory.
# Assign and export separately so the command substitution's exit status
# is not masked (ShellCheck SC2155); quote to survive paths with spaces.
PYTHONPATH=$(dirname "$(dirname "$PWD")")
export PYTHONPATH

python "$PYTHONPATH/sota-implementations/pilco/pilco.py" \
  logger.backend=wandb \
  logger.project_name="$project_name" \
  logger.group_name="$group_name"

# Capture the exit status of the Python command
exit_status=$?
# Write the exit status to a file
if [ "$exit_status" -eq 0 ]; then
  echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log
else
  echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log
fi
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Hydra configuration for the PILCO example (pilco.py).
env:
  env_name: InvertedPendulum-v5  # gym environment id passed to make_env
  library: gym
  device: null  # null -> auto-select via get_available_device()
logger:
  backend: wandb  # falsy backend disables logging
  project_name: torchrl_pilco
  group_name: null
  video: True  # also enables from_pixels in make_env
optim:
  policy_lr: 5e-3  # Adam learning rate for the RBF policy
pilco:
  horizon: 40  # imagined-rollout length per policy update
  initial_rollout_length: 200  # random-policy steps seeding the dataset
  max_rollout_length: 350  # dataset cap; oldest transitions dropped
  epochs: 3  # world-model refit / evaluation cycles
  policy_training_steps: 100  # gradient steps per epoch
  policy_n_basis: 10  # number of RBF basis functions in the controller
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
import hydra
2+
import tensordict
3+
import torch
4+
from omegaconf import DictConfig
5+
6+
from tensordict import TensorDict, TensorDictBase
7+
from tensordict.nn import TensorDictModule
8+
from torchrl._utils import get_available_device
9+
from torchrl.envs import EnvBase
10+
from torchrl.envs.utils import RandomPolicy
11+
from torchrl.record.loggers import generate_exp_name, get_logger, Logger
12+
13+
from utils import (
14+
BoTorchGPWorldModel,
15+
ImaginedEnv,
16+
make_env,
17+
pendulum_cost,
18+
RBFController,
19+
)
20+
21+
22+
def pilco_loop(
    cfg: DictConfig, env: EnvBase, logger: Logger | None = None
) -> TensorDictModule:
    """Train a PILCO agent on ``env`` and return the trained policy module.

    Each epoch: (1) fit a GP world model on the collected transitions and
    freeze it, (2) optimize the RBF policy by backpropagating the mean
    trajectory cost through imagined rollouts of that model, (3) evaluate
    the policy on the real environment and append the evaluation rollout to
    the dataset (capped at ``cfg.pilco.max_rollout_length`` transitions).

    Args:
        cfg: hydra config; reads the ``pilco`` and ``optim`` sections.
        env: the real environment to learn on.
        logger: optional torchrl logger for scalar train/eval metrics.

    Returns:
        The ``TensorDictModule`` wrapping the trained RBF controller.
    """
    obs_dim = env.observation_spec["observation"].shape[-1]
    action_dim = env.action_spec.shape[-1]

    # Seed the dataset with a fixed-length random rollout; do not stop on
    # done so the buffer always has initial_rollout_length transitions.
    random_policy = RandomPolicy(action_spec=env.action_spec)
    rollout = env.rollout(
        max_steps=cfg.pilco.initial_rollout_length,
        policy=random_policy,
        break_when_all_done=False,
        break_when_any_done=False,
    )

    # RBF controller in float64 to match the GP world model's precision.
    base_policy = (
        RBFController(
            input_dim=obs_dim,
            output_dim=action_dim,
            n_basis=cfg.pilco.policy_n_basis,
            max_action=env.action_spec.high,
        )
        .to(env.device)
        .double()
    )
    # The policy consumes/produces state *distributions* (mean, var), not
    # point states — PILCO propagates uncertainty through the rollout.
    policy_module = TensorDictModule(
        module=base_policy,
        in_keys=[("observation", "mean"), ("observation", "var")],
        out_keys=[
            ("action", "mean"),
            ("action", "var"),
            ("action", "cross_covariance"),
        ],
    )
    optimizer = torch.optim.Adam(policy_module.parameters(), lr=cfg.optim.policy_lr)

    # Initial state distribution for imagined rollouts: zero mean with small
    # isotropic variance. NOTE(review): assumes the real env resets near the
    # origin — confirm for envs other than InvertedPendulum.
    dtype = torch.float64
    initial_observation = TensorDict(
        {
            ("observation", "mean"): torch.zeros(
                obs_dim, device=env.device, dtype=dtype
            ),
            ("observation", "var"): torch.eye(obs_dim, device=env.device, dtype=dtype)
            * 1e-3,
        }
    )

    for epoch in range(cfg.pilco.epochs):
        # Refit the GP world model from scratch on the current dataset, then
        # freeze/detach it so policy gradients do not flow into the model.
        base_world_model = BoTorchGPWorldModel(
            obs_dim=obs_dim, action_dim=action_dim
        ).to(env.device)
        base_world_model.fit(rollout)
        base_world_model.freeze_and_detach()

        world_model_module = TensorDictModule(
            module=base_world_model,
            in_keys=["action", "observation"],
            out_keys=[("next_observation", "mean"), ("next_observation", "var")],
        )

        # Differentiable model-based environment driven by the frozen GP.
        imagined_env = ImaginedEnv(
            world_model_module=world_model_module,
            base_env=env,
        )
        reset_td = initial_observation.expand(*imagined_env.batch_size)

        for step in range(cfg.pilco.policy_training_steps):
            logger_step = (epoch * cfg.pilco.policy_training_steps) + step
            optimizer.zero_grad()

            # Imagined rollout through the world model; gradients flow from
            # the cost back into the policy parameters.
            imagined_data = imagined_env.rollout(
                max_steps=cfg.pilco.horizon,
                policy=policy_module,
                tensordict=reset_td,
            )

            obs = imagined_data["observation"]
            cost = pendulum_cost(obs)
            loss = cost.mean()

            loss.backward()
            optimizer.step()

            if logger:
                logger.log_scalar(
                    "train/trajectory_cost", loss.item(), step=logger_step
                )

        # Adapter letting the distribution-valued policy act in the real env:
        # wrap the observed state as a delta distribution (zero variance) and
        # emit the action mean under the plain "action" key.
        def policy_for_env(td: TensorDictBase) -> TensorDictBase:
            obs = td["observation"]
            device, dtype = obs.device, obs.dtype

            # env.rollout may hand us an unbatched td; add a batch dim so the
            # controller sees a leading batch axis, then strip it on the way out.
            is_unbatched = obs.ndim == 1
            if is_unbatched:
                obs = obs.unsqueeze(0)

            batch_shape = obs.shape[:-1]
            D = obs.shape[-1]

            policy_in = TensorDict(
                {
                    "observation": TensorDict(
                        {
                            "mean": obs,
                            # Zero covariance: the real observation is known exactly.
                            "var": torch.zeros(
                                (*batch_shape, D, D), device=device, dtype=dtype
                            ),
                        },
                        batch_size=batch_shape,
                    )
                },
                batch_size=batch_shape,
                device=device,
            )

            policy_out = policy_module(policy_in)
            action_mean = policy_out["action", "mean"]

            if is_unbatched:
                action_mean = action_mean.squeeze(0)

            td["action"] = action_mean
            return td

        # Evaluate on the real environment (stops at first done).
        test_rollout = env.rollout(
            max_steps=1000, policy=policy_for_env, break_when_any_done=True
        )

        # NOTE(review): "episode_reward" and "step_count" presuppose reward-sum
        # and step-count transforms applied inside make_env — verify there.
        reward = test_rollout["episode_reward"][-1].item()
        steps = test_rollout["step_count"].max().item()

        if logger:
            # NOTE(review): logger_step is bound in the inner loop above; this
            # raises NameError if cfg.pilco.policy_training_steps == 0.
            logger.log_scalar("eval/reward", reward, step=logger_step)
            logger.log_scalar("eval/steps", steps, step=logger_step)

        # Grow the dataset with real experience, keeping only the most
        # recent max_rollout_length transitions.
        rollout = tensordict.cat([rollout, test_rollout], dim=0)

        if len(rollout) > cfg.pilco.max_rollout_length:
            rollout = rollout[-cfg.pilco.max_rollout_length :]

    return policy_module
162+
163+
164+
@hydra.main(config_path="", config_name="config", version_base="1.1")
def main(cfg: DictConfig) -> None:
    """Entry point: build the environment, optionally set up a logger, and
    run the PILCO training loop.

    Args:
        cfg: hydra config (see config.yaml) with ``env``, ``logger``,
            ``optim`` and ``pilco`` sections.
    """
    device = torch.device(cfg.device) if cfg.device else get_available_device()

    env = make_env(cfg.env.env_name, device, from_pixels=cfg.logger.video)

    # Fix: `logger` was previously bound only inside the `if` branch, so a
    # falsy cfg.logger.backend raised NameError at the pilco_loop call below.
    logger = None
    if cfg.logger.backend:
        exp_name = generate_exp_name("PILCO", cfg.env.env_name)
        logger = get_logger(
            cfg.logger.backend,
            logger_name="pilco",
            experiment_name=exp_name,
            wandb_kwargs={
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )

    pilco_loop(cfg, env, logger=logger)

    if not env.is_closed:
        env.close()


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)