 from __future__ import annotations
 
 import time
+import warnings
 
 import hydra
 import numpy as np
 import torch
 import tqdm
-from torchrl._utils import logger as torchrl_logger
+from tensordict.nn import CudaGraphModule
+from torchrl._utils import logger as torchrl_logger, timeit
 from torchrl.envs.libs.gym import set_gym_backend
-
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
 from torchrl.record import VideoRecorder
@@ -65,8 +66,7 @@ def main(cfg: "DictConfig"): # noqa: F821
     )
 
     # Create policy model
-    actor = make_odt_model(cfg)
-    policy = actor.to(model_device)
+    policy = make_odt_model(cfg, device=model_device)
 
     # Create loss
     loss_module = make_odt_loss(cfg.loss, policy)
@@ -80,13 +80,46 @@ def main(cfg: "DictConfig"): # noqa: F821
     inference_policy = DecisionTransformerInferenceWrapper(
         policy=policy,
         inference_context=cfg.env.inference_context,
-    ).to(model_device)
+        device=model_device,
+    )
     inference_policy.set_tensor_keys(
         observation="observation_cat",
         action="action_cat",
         return_to_go="return_to_go_cat",
     )
 
+    def update(data):
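+        # Fused optimization step: forward pass, one backward, and both optimizer
+        # updates, so the whole step can be compiled or captured as a single CUDA graph.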
+        transformer_optim.zero_grad(set_to_none=True)
+        temperature_optim.zero_grad(set_to_none=True)
+        # Compute loss
+        loss_vals = loss_module(data.to(model_device))
+        transformer_loss = loss_vals["loss_log_likelihood"] + loss_vals["loss_entropy"]
+        temperature_loss = loss_vals["loss_alpha"]
+
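+        # A single backward over the summed losses fills gradients for both the
+        # transformer parameters and the temperature parameter.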
+        (temperature_loss + transformer_loss).backward()
+        torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad)
+
+        transformer_optim.step()
+        temperature_optim.step()
+
+        return loss_vals.detach()
+
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            compile_mode = "default"
+        update = torch.compile(update, mode=compile_mode, dynamic=False)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        if cfg.optim.optimizer == "lamb":
+            raise ValueError(
+                "cudagraphs isn't compatible with the Lamb optimizer. Use optim.optimizer=Adam instead."
+            )
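+        # warmup=50 runs the update eagerly for the first iterations before the CUDA graph is captured.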
+        update = CudaGraphModule(update, warmup=50)
+
     pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps)
 
     pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps
@@ -100,35 +133,29 @@ def main(cfg: "DictConfig"): # noqa: F821
     start_time = time.time()
     for i in range(pretrain_gradient_steps):
         pbar.update(1)
-        # Sample data
-        data = offline_buffer.sample()
-        # Compute loss
-        loss_vals = loss_module(data.to(model_device))
-        transformer_loss = loss_vals["loss_log_likelihood"] + loss_vals["loss_entropy"]
-        temperature_loss = loss_vals["loss_alpha"]
-
-        transformer_optim.zero_grad()
-        torch.nn.utils.clip_grad_norm_(policy.parameters(), clip_grad)
-        transformer_loss.backward()
-        transformer_optim.step()
+        with timeit("sample"):
+            # Sample data
+            data = offline_buffer.sample()
 
-        temperature_optim.zero_grad()
-        temperature_loss.backward()
-        temperature_optim.step()
+        with timeit("update"):
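+            # Mark the start of a new iteration so tensors produced by the previous
+            # cudagraph replay can be safely overwritten.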
+            torch.compiler.cudagraph_mark_step_begin()
+            loss_vals = update(data.to(model_device))
 
         scheduler.step()
 
         # Log metrics
         to_log = {
-            "train/loss_log_likelihood": loss_vals["loss_log_likelihood"].item(),
-            "train/loss_entropy": loss_vals["loss_entropy"].item(),
-            "train/loss_alpha": loss_vals["loss_alpha"].item(),
-            "train/alpha": loss_vals["alpha"].item(),
-            "train/entropy": loss_vals["entropy"].item(),
+            "train/loss_log_likelihood": loss_vals["loss_log_likelihood"],
+            "train/loss_entropy": loss_vals["loss_entropy"],
+            "train/loss_alpha": loss_vals["loss_alpha"],
+            "train/alpha": loss_vals["alpha"],
+            "train/entropy": loss_vals["entropy"],
         }
 
         # Evaluation
-        with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
+        with torch.no_grad(), set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), timeit("eval"):
             inference_policy.eval()
             if i % pretrain_log_interval == 0:
                 eval_td = test_env.rollout(
@@ -143,6 +170,11 @@ def main(cfg: "DictConfig"): # noqa: F821
             eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
         )
 
+        if i % 200 == 0:
+            to_log.update(timeit.todict(prefix="time"))
+            timeit.print()
+            timeit.erase()
+
         if logger is not None:
             log_metrics(logger, to_log, i)