 import time
 from contextlib import nullcontext
 from pprint import pformat
-from typing import Any
+from typing import Any, Callable
 
 import torch
 from termcolor import colored
     get_safe_torch_device,
     has_method,
     init_logging,
+    get_accelerate_config,
+    is_launched_with_accelerate
 )
 from lerobot.common.utils.wandb_utils import WandBLogger
 from lerobot.configs import parser
@@ -63,40 +65,55 @@ def update_policy(
     lr_scheduler=None,
     use_amp: bool = False,
     lock=None,
+    accelerator: Callable = None,
 ) -> tuple[MetricsTracker, dict]:
     start_time = time.perf_counter()
     device = get_device_from_parameters(policy)
     policy.train()
-    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
+    with torch.autocast(device_type=device.type) if use_amp and accelerator is None else nullcontext():
         loss, output_dict = policy.forward(batch)
         # TODO(rcadene): policy.unnormalize_outputs(out_dict)
-    grad_scaler.scale(loss).backward()
 
-    # Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
-    grad_scaler.unscale_(optimizer)
-
-    grad_norm = torch.nn.utils.clip_grad_norm_(
-        policy.parameters(),
-        grad_clip_norm,
-        error_if_nonfinite=False,
-    )
+    if accelerator:
+        accelerator.backward(loss)
+        accelerator.unscale_gradients(optimizer=optimizer)
+        grad_norm = torch.nn.utils.clip_grad_norm_(
+            policy.parameters(),
+            grad_clip_norm,
+            error_if_nonfinite=False,
+        )
+        optimizer.step()
+    else:
+        grad_scaler.scale(loss).backward()
+        # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
+        grad_scaler.unscale_(optimizer)
+
+        grad_norm = torch.nn.utils.clip_grad_norm_(
+            policy.parameters(),
+            grad_clip_norm,
+            error_if_nonfinite=False,
+        )
 
-    # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
-    # although it still skips optimizer.step() if the gradients contain infs or NaNs.
-    with lock if lock is not None else nullcontext():
-        grad_scaler.step(optimizer)
-    # Updates the scale for next iteration.
-    grad_scaler.update()
+        # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
+        # although it still skips optimizer.step() if the gradients contain infs or NaNs.
+        with lock if lock is not None else nullcontext():
+            grad_scaler.step(optimizer)
+        # Updates the scale for next iteration.
+        grad_scaler.update()
 
     optimizer.zero_grad()
 
     # Step through pytorch scheduler at every batch instead of epoch
     if lr_scheduler is not None:
         lr_scheduler.step()
 
-    if has_method(policy, "update"):
-        # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
-        policy.update()
+    if accelerator:
+        if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):  # FIXME(mshukor): avoid accelerator.unwrap_model?
+            accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
+    else:
+        if has_method(policy, "update"):
+            # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
+            policy.update()
 
     train_metrics.loss = loss.item()
     train_metrics.grad_norm = grad_norm.item()
@@ -106,21 +123,25 @@ def update_policy(
 
 
 @parser.wrap()
-def train(cfg: TrainPipelineConfig):
+def train(cfg: TrainPipelineConfig, accelerator: Callable = None):
     cfg.validate()
     logging.info(pformat(cfg.to_dict()))
 
+    if accelerator and not accelerator.is_main_process:
+        # Disable wandb logging on non-main processes.
+        cfg.wandb.enable = False
+
     if cfg.wandb.enable and cfg.wandb.project:
         wandb_logger = WandBLogger(cfg)
     else:
         wandb_logger = None
         logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
 
     if cfg.seed is not None:
-        set_seed(cfg.seed)
+        set_seed(cfg.seed, accelerator=accelerator)
 
     # Check device is available
-    device = get_safe_torch_device(cfg.device, log=True)
+    device = get_safe_torch_device(cfg.device, log=True, accelerator=accelerator)
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
 
@@ -141,7 +162,7 @@ def train(cfg: TrainPipelineConfig):
         device=device,
         ds_meta=dataset.meta,
     )
-
+    policy.to(device)
     logging.info("Creating optimizer and scheduler")
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
     grad_scaler = GradScaler(device, enabled=cfg.use_amp)
@@ -184,6 +205,10 @@ def train(cfg: TrainPipelineConfig):
         pin_memory=device.type != "cpu",
         drop_last=False,
     )
+    if accelerator:
+        policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
+            policy, optimizer, dataloader, lr_scheduler
+        )
     dl_iter = cycle(dataloader)
 
     policy.train()
@@ -197,7 +222,7 @@ def train(cfg: TrainPipelineConfig):
     }
 
     train_tracker = MetricsTracker(
-        cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
+        cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step, accelerator=accelerator
     )
 
     logging.info("Start offline training on a fixed dataset")
@@ -219,6 +244,7 @@ def train(cfg: TrainPipelineConfig):
             grad_scaler=grad_scaler,
             lr_scheduler=lr_scheduler,
             use_amp=cfg.use_amp,
+            accelerator=accelerator,
         )
 
         # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -238,21 +264,26 @@ def train(cfg: TrainPipelineConfig):
                 wandb_logger.log_dict(wandb_log_dict, step)
             train_tracker.reset_averages()
 
-        if cfg.save_checkpoint and is_saving_step:
+        if cfg.save_checkpoint and is_saving_step and (not accelerator or accelerator.is_main_process):
             logging.info(f"Checkpoint policy after step {step}")
             checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
-            save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
+            save_checkpoint(checkpoint_dir, step, cfg, policy if not accelerator else accelerator.unwrap_model(policy), optimizer, lr_scheduler)
             update_last_checkpoint(checkpoint_dir)
             if wandb_logger:
                 wandb_logger.log_policy(checkpoint_dir)
 
+        if accelerator:
+            accelerator.wait_for_everyone()
         if cfg.env and is_eval_step:
             step_id = get_step_identifier(step, cfg.steps)
             logging.info(f"Eval policy at step {step}")
-            with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+            with (
+                torch.no_grad(),
+                torch.autocast(device_type=device.type) if cfg.use_amp and not accelerator else nullcontext(),
+            ):
                 eval_info = eval_policy(
                     eval_env,
-                    policy,
+                    policy if not accelerator else accelerator.unwrap_model(policy),
                     cfg.eval.n_episodes,
                     videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
                     max_episodes_rendered=4,
@@ -265,7 +296,7 @@ def train(cfg: TrainPipelineConfig):
                 "eval_s": AverageMeter("eval_s", ":.3f"),
             }
             eval_tracker = MetricsTracker(
-                cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
+                cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step, accelerator=None
             )
             eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
             eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
@@ -283,4 +314,11 @@ def train(cfg: TrainPipelineConfig):
 
 if __name__ == "__main__":
     init_logging()
-    train()
+    if is_launched_with_accelerate():
+        import accelerate
+        # We set step_scheduler_with_optimizer=False to prevent accelerate from
+        # adjusting the lr_scheduler steps based on the num_processes
+        accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False)
+        train(accelerator=accelerator)
+    else:
+        train()
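
Usage sketch (illustrative, not part of this diff): a direct python invocation keeps the original single-process behavior, while launching the same entry point through Accelerate makes is_launched_with_accelerate() create the Accelerator before calling train(). The script path, process count, and config flags below are assumptions for illustration only:

    accelerate launch --num_processes=2 --mixed_precision=no lerobot/scripts/train.py \
        --dataset.repo_id=lerobot/pusht --policy.type=act

With step_scheduler_with_optimizer=False, the lr_scheduler is stepped once per optimizer step regardless of how many processes accelerate launch starts.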