Skip to content

Commit 1eccb49

Browse files
Ibinarriaga8jorge.ibinarriaga.robles.becasvmoens
authored
[Algorithm] SOTA discrete offline CQL (#3098)
Co-authored-by: jorge.ibinarriaga.robles.becas <[email protected]> Co-authored-by: vmoens <[email protected]>
1 parent 009f4ce commit 1eccb49

File tree

9 files changed

+344
-7
lines changed

9 files changed

+344
-7
lines changed

.github/unittest/linux_sota/scripts/environment.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,5 @@ dependencies:
2929
- coverage
3030
- vmas
3131
- transformers
32+
- minari
33+
- minari[create]

.github/unittest/linux_sota/scripts/test_sota.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,14 @@
105105
collector.env_per_collector=2 \
106106
replay_buffer.size=120 \
107107
logger.backend=
108+
""",
109+
"discrete_cql_offline": """python sota-implementations/cql/discrete_cql_offline.py \
110+
collector.total_frames=48 \
111+
collector.init_random_frames=10 \
112+
collector.frames_per_batch=16 \
113+
collector.env_per_collector=2 \
114+
replay_buffer.batch_size=10 \
115+
logger.backend=
108116
""",
109117
"redq": """python sota-implementations/redq/redq.py \
110118
num_workers=4 \

sota-check/run_discrete_cql.sh

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#!/bin/bash

#SBATCH --job-name=cql_discrete_offline
#SBATCH --ntasks=32
#SBATCH --cpus-per-task=1
#SBATCH --gres=gpu:1
#SBATCH --output=slurm_logs/cql_discrete_offline_%j.txt
#SBATCH --error=slurm_errors/cql_discrete_offline_%j.txt

# Tag the run with the current commit so the logged runs can be traced back.
current_commit=$(git rev-parse --short HEAD)
project_name="torchrl-example-check-$current_commit"
group_name="cql_discrete_offline"

# Make the repository root importable for the sota-implementations scripts.
export PYTHONPATH=$(dirname $(dirname $PWD))
python $PYTHONPATH/sota-implementations/cql/discrete_cql_offline.py \
  logger.backend=wandb \
  logger.project_name="$project_name" \
  logger.group_name="$group_name"

# Capture the exit status of the Python command
exit_status=$?
# Write the exit status to a file
if [ $exit_status -eq 0 ]; then
  # FIX: ">>>" is a bash syntax error; ">>" is the append redirection
  # (consistent with the else branch below).
  echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log
else
  echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log
fi
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
"""CQL Example.
7+
8+
This is a self-contained example of a discrete offline CQL training script.
9+
10+
The helper functions are coded in the utils.py associated with this script.
11+
"""
12+
from __future__ import annotations
13+
14+
import warnings
15+
16+
import hydra
17+
import numpy as np
18+
import torch
19+
import tqdm
20+
from tensordict.nn import CudaGraphModule
21+
from torchrl._utils import timeit
22+
from torchrl.envs.utils import ExplorationType, set_exploration_type
23+
from torchrl.record.loggers import generate_exp_name, get_logger
24+
from utils import (
25+
dump_video,
26+
log_metrics,
27+
make_discrete_cql_optimizer,
28+
make_discrete_loss,
29+
make_discretecql_model,
30+
make_environment,
31+
make_offline_discrete_replay_buffer,
32+
)
33+
34+
torch.set_float32_matmul_precision("high")
35+
36+
37+
@hydra.main(version_base="1.1", config_path="", config_name="discrete_offline_config")
def main(cfg):  # noqa: F821
    """Train a discrete CQL agent offline on a Minari-backed replay buffer.

    The script builds the dataset, model, loss and optimizer via the helpers
    in ``utils.py``, runs ``cfg.optim.gradient_steps`` gradient updates and
    periodically evaluates the greedy policy on ``eval_env``.
    """
    # Resolve the device: fall back to the first CUDA device when available.
    device = cfg.optim.device
    if device in ("", None):
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Create logger (only when a backend is configured).
    exp_name = generate_exp_name("DiscreteCQL", cfg.logger.exp_name)
    logger = None
    if cfg.logger.backend:
        logger = get_logger(
            logger_type=cfg.logger.backend,
            logger_name="discretecql_logging",
            experiment_name=exp_name,
            wandb_kwargs={
                "mode": cfg.logger.mode,
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )

    # Set seeds. The env-level seed is deprecated but still honored.
    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)
    if cfg.env.seed is not None:
        warnings.warn(
            "The seed in the environment config is deprecated. "
            "Please set the seed in the optim config instead."
        )

    # Create replay buffer (generates a random-policy Minari dataset on the fly).
    replay_buffer = make_offline_discrete_replay_buffer(cfg.replay_buffer)

    # Create env. The train env is only needed to build the model specs.
    train_env, eval_env = make_environment(
        cfg, train_num_envs=1, eval_num_envs=cfg.logger.eval_envs, logger=logger
    )

    # Create agent
    model, explore_policy = make_discretecql_model(cfg, train_env, eval_env, device)

    # The train env served its purpose (spec extraction) — release it.
    del train_env

    # Create loss
    loss_module, target_net_updater = make_discrete_loss(cfg.loss, model, device)

    # Create optimizers
    optimizer = make_discrete_cql_optimizer(cfg, loss_module)  # optimizer for CQL loss

    def update(data):
        """One gradient step: Q-learning loss + CQL regularizer, then soft target update."""
        loss_vals = loss_module(data)

        q_loss = loss_vals["loss_qvalue"]
        cql_loss = loss_vals["loss_cql"]

        # Total loss = Q-learning loss + CQL regularization
        loss = q_loss + cql_loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        # Soft update of target Q-network
        target_net_updater.step()

        # Detach to avoid keeping computation graph in logging
        return loss.detach(), loss_vals.detach()

    # Optional torch.compile / CUDA-graph acceleration of the update step.
    compile_mode = None
    if cfg.compile.compile:
        if cfg.compile.compile_mode not in (None, ""):
            compile_mode = cfg.compile.compile_mode
        elif cfg.compile.cudagraphs:
            compile_mode = "default"
        else:
            compile_mode = "reduce-overhead"
        update = torch.compile(update, mode=compile_mode)
    if cfg.compile.cudagraphs:
        # NOTE(review): this warning message is in Spanish — it should be
        # translated to English for consistency with the rest of the codebase.
        warnings.warn(
            "CudaGraphModule es experimental y puede llevar a resultados incorrectos silenciosamente. Úsalo con precaución.",
            category=UserWarning,
        )
        update = CudaGraphModule(update, warmup=50)

    pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)

    gradient_steps = cfg.optim.gradient_steps
    evaluation_interval = cfg.logger.eval_iter
    eval_steps = cfg.logger.eval_steps
    # NOTE: cfg.optim.policy_eval_start is not used in the discrete offline
    # setting (it is an actor-evaluation knob for continuous CQL); the dead
    # read-and-convert-to-tensor of it has been removed.

    # Training loop
    for i in range(gradient_steps):
        timeit.printevery(1000, gradient_steps, erase=True)
        pbar.update(1)
        # sample data
        with timeit("sample"):
            data = replay_buffer.sample()

        with timeit("update"):
            torch.compiler.cudagraph_mark_step_begin()
            loss, loss_vals = update(data.to(device))

        # log metrics
        metrics_to_log = {
            "loss": loss.cpu(),
            **loss_vals.cpu(),
        }

        # evaluation
        with timeit("log/eval"):
            if i % evaluation_interval == 0:
                with set_exploration_type(
                    ExplorationType.DETERMINISTIC
                ), torch.no_grad():
                    eval_td = eval_env.rollout(
                        max_steps=eval_steps,
                        policy=explore_policy,
                        auto_cast_to_device=True,
                    )
                    eval_env.apply(dump_video)

                # eval_td: batch of shape [num_episodes, max_steps, ...];
                # mean over episodes of the per-episode reward sum.
                eval_reward = eval_td["next", "reward"].sum(1).mean().item()
                metrics_to_log["evaluation_reward"] = eval_reward

        with timeit("log"):
            metrics_to_log.update(timeit.todict(prefix="time"))
            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
            log_metrics(logger, metrics_to_log, i)

    pbar.close()
    if not eval_env.is_closed:
        eval_env.close()


if __name__ == "__main__":
    main()

sota-implementations/cql/discrete_cql_online.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
torch.set_float32_matmul_precision("high")
3737

3838

39-
@hydra.main(version_base="1.1", config_path="", config_name="discrete_cql_config")
39+
@hydra.main(version_base="1.1", config_path="", config_name="discrete_online_config")
4040
def main(cfg: DictConfig): # noqa: F821
4141
device = cfg.optim.device
4242
if device in ("", None):
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# env and task
2+
env:
3+
name: CartPole-v1 # CartPole environment for discrete action space
4+
task: ""
5+
library: minari
6+
n_samples_stats: 1000
7+
seed: 0
8+
backend: gymnasium
9+
10+
# Collector
11+
collector:
12+
frames_per_batch: 200
13+
total_frames: 1_000_000
14+
multi_step: 0
15+
init_random_frames: 1000
16+
env_per_collector: 1
17+
device:
18+
max_frames_per_traj: 200
19+
annealing_frames: 10000
20+
eps_start: 1.0
21+
eps_end: 0.01
22+
23+
24+
# logger
25+
logger:
26+
backend: wandb
27+
project_name: torchrl_example_cql
28+
group_name: null
29+
exp_name: cql_${replay_buffer.dataset}
30+
eval_iter: 5000 # eval interval in gradient steps
31+
eval_steps: 1000 # evaluation steps per eval
32+
mode: online
33+
eval_envs: 5 # number of evaluation environments
34+
video: True
35+
36+
# replay buffer
37+
replay_buffer:
38+
env: CartPole-v1
39+
  dataset: CartPole-v1-random-v1  # script-generated dataset id; "v1" matches the CartPole-v1 env above (there is no CartPole-v2)
40+
batch_size: 128
41+
episodes: 10000
42+
43+
# optimization
44+
optim:
45+
device: null
46+
lr: 3e-4 # learning rate
47+
weight_decay: 0.0
48+
gradient_steps: 100_000
49+
policy_eval_start: 40_000
50+
51+
# model
52+
model:
53+
hidden_sizes: [256, 256]
54+
activation: relu
55+
56+
# loss
57+
loss:
58+
loss_function: l2
59+
gamma: 0.99
60+
tau: 0.005
61+
action_space: categorical
62+
63+
compile:
64+
compile: False
65+
compile_mode:
66+
cudagraphs: False

sota-implementations/cql/utils.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,49 @@ def make_offline_replay_buffer(rb_cfg):
195195
return data
196196

197197

198+
def make_offline_discrete_replay_buffer(rb_cfg):
    """Build an offline replay buffer from a freshly generated Minari dataset.

    A random policy is rolled out in ``rb_cfg.env`` for ``rb_cfg.episodes``
    episodes; the transitions are recorded into a local Minari dataset named
    ``rb_cfg.dataset``, loaded into a :class:`MinariExperienceReplay`, and the
    on-disk dataset is deleted afterwards.

    Args:
        rb_cfg: replay-buffer config with ``env``, ``dataset``, ``episodes``
            and ``batch_size`` entries.

    Returns:
        A replay buffer yielding float-converted transition batches.
    """
    import gymnasium as gym
    import minari
    from minari import DataCollector

    # Create custom minari dataset from environment
    env = gym.make(rb_cfg.env)
    env = DataCollector(env)

    for episode in range(rb_cfg.episodes):
        # FIX: vary the seed per episode. A constant seed would reset the
        # environment to the identical initial state every episode, hurting
        # the diversity of the generated offline dataset.
        env.reset(seed=123 + episode)
        while True:
            action = env.action_space.sample()
            _, _, terminated, truncated, _ = env.step(action)
            if terminated or truncated:
                break

    env.create_dataset(
        dataset_id=rb_cfg.dataset,
        algorithm_name="Random-Policy",
        code_permalink="https://github.com/Farama-Foundation/Minari",
        author="Farama",
        author_email="contact@farama.org",
    )

    data = MinariExperienceReplay(
        dataset_id=rb_cfg.dataset,
        split_trajs=False,
        batch_size=rb_cfg.batch_size,
        load_from_local_minari=True,
        sampler=SamplerWithoutReplacement(drop_last=True),
        prefetch=4,
    )

    # Observations come out as float64 from gymnasium; cast for the model.
    data.append_transform(DoubleToFloat())

    # Clean up the on-disk dataset (assumes the replay buffer has already
    # materialized the data into its own storage — TODO confirm for the
    # installed minari/torchrl versions).
    minari.delete_dataset(rb_cfg.dataset)

    return data
239+
240+
198241
# ====================================================================
199242
# Model
200243
# -----
@@ -354,11 +397,21 @@ def make_continuous_loss(loss_cfg, model, device: torch.device | None = None):
354397

355398

356399
def make_discrete_loss(loss_cfg, model, device: torch.device | None = None):
357-
loss_module = DiscreteCQLLoss(
358-
model,
359-
loss_function=loss_cfg.loss_function,
360-
delay_value=True,
361-
)
400+
401+
if "action_space" in loss_cfg: # especify action space
402+
loss_module = DiscreteCQLLoss(
403+
model,
404+
loss_function=loss_cfg.loss_function,
405+
action_space=loss_cfg.action_space,
406+
delay_value=True,
407+
)
408+
else:
409+
loss_module = DiscreteCQLLoss(
410+
model,
411+
loss_function=loss_cfg.loss_function,
412+
delay_value=True,
413+
)
414+
362415
loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
363416
target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)
364417

test/llm/test_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2386,7 +2386,7 @@ def test_batching_continuous_throughput(
23862386
assert len(processing_events) > 0, "No processing occurred"
23872387

23882388
# Check that processing happened across multiple threads (indicating concurrent processing)
2389-
thread_ids = set(event["thread_id"] for event in processing_events)
2389+
thread_ids = {event["thread_id"] for event in processing_events} # noqa
23902390
assert (
23912391
len(thread_ids) > 1
23922392
), f"All processing happened in single thread: {thread_ids}"

0 commit comments

Comments
 (0)