Commit 86e8c1d

supporting accelerate

1 parent 8699a28 commit 86e8c1d

4 files changed (+102 additions, −42 deletions)


lerobot/common/utils/logging_utils.py

Lines changed: 4 additions & 2 deletions

@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any
+from typing import Any, Callable
 
 from lerobot.common.utils.utils import format_big_number
 
@@ -93,12 +93,14 @@ def __init__(
         num_episodes: int,
         metrics: dict[str, AverageMeter],
         initial_step: int = 0,
+        accelerator: Callable = None,
     ):
         self.__dict__.update({k: None for k in self.__keys__})
         self._batch_size = batch_size
         self._num_frames = num_frames
         self._avg_samples_per_ep = num_frames / num_episodes
         self.metrics = metrics
+        self.accelerator = accelerator
 
         self.steps = initial_step
         # A sample is an (observation,action) pair, where observation and action
@@ -128,7 +130,7 @@ def step(self) -> None:
         Updates metrics that depend on 'step' for one step.
         """
         self.steps += 1
-        self.samples += self._batch_size
+        self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
         self.episodes = self.samples / self._avg_samples_per_ep
         self.epochs = self.samples / self._num_frames
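
Note on the new samples accounting: under data parallelism, each process consumes its own batch of `batch_size` samples per optimizer step, so the tracker multiplies by `accelerator.num_processes` to report global throughput. A minimal sketch with illustrative numbers (not from the commit):

    batch_size, num_processes = 8, 4  # illustrative values
    # Single process (accelerator is None): step() advances samples by 8.
    # Accelerate with 4 processes: step() advances samples by 8 * 4 = 32,
    # so episodes/epochs progress reflects all ranks, not just the local one.
    assert batch_size * num_processes == 32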

lerobot/common/utils/random_utils.py

Lines changed: 5 additions & 3 deletions

@@ -16,7 +16,7 @@
 import random
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Generator
+from typing import Any, Generator, Callable
 
 import numpy as np
 import torch
@@ -163,14 +163,16 @@ def set_rng_state(random_state_dict: dict[str, Any]):
     torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
 
 
-def set_seed(seed) -> None:
+def set_seed(seed, accelerator: Callable = None) -> None:
     """Set seed for reproducibility."""
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)
-
+    if accelerator:
+        from accelerate.utils import set_seed
+        set_seed(seed)
 
 @contextmanager
 def seeded_context(seed: int) -> Generator[None, None, None]:
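
Note on seeding: the local `from accelerate.utils import set_seed` shadows the enclosing function's name, so the call dispatches to accelerate's implementation, which re-seeds `random`, NumPy, and torch (CPU and CUDA) in one pass so every process starts from identical RNG streams. A hedged sketch of the relevant accelerate API (the `device_specific` flag is an option accelerate offers, not something this commit uses):

    from accelerate.utils import set_seed

    set_seed(1000)                        # identical streams on every rank (this commit's behavior)
    set_seed(1000, device_specific=True)  # roughly seed + process_index, for decorrelated ranks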

lerobot/common/utils/utils.py

Lines changed: 24 additions & 6 deletions

@@ -20,10 +20,10 @@
 from copy import copy
 from datetime import datetime, timezone
 from pathlib import Path
-
+from typing import Callable
 import numpy as np
 import torch
-
+from typing import Any
 
 def none_or_int(value):
     if value == "None":
@@ -50,12 +50,12 @@ def auto_select_torch_device() -> torch.device:
     return torch.device("cpu")
 
 
-def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
+def get_safe_torch_device(try_device: str, log: bool = False, accelerator: Callable = None) -> torch.device:
     """Given a string, return a torch.device with checks on whether the device is available."""
     match try_device:
         case "cuda":
             assert torch.cuda.is_available()
-            device = torch.device("cuda")
+            device = accelerator.device if accelerator else torch.device("cuda")
         case "mps":
             assert torch.backends.mps.is_available()
             device = torch.device("mps")
@@ -103,7 +103,7 @@ def is_amp_available(device: str):
         raise ValueError(f"Unknown device '{device}.")
 
 
-def init_logging():
+def init_logging(accelerator: Callable = None):
     def custom_format(record):
         dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         fnameline = f"{record.pathname}:{record.lineno}"
@@ -120,7 +120,10 @@ def custom_format(record):
     console_handler = logging.StreamHandler()
     console_handler.setFormatter(formatter)
     logging.getLogger().addHandler(console_handler)
-
+    if accelerator is not None and not accelerator.is_main_process:
+        # Disable duplicate logging on non-main processes
+        logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
+        logging.getLogger().setLevel(logging.WARNING)
 
 def format_big_number(num, precision=0):
     suffixes = ["", "K", "M", "B", "T", "Q"]
@@ -216,3 +219,18 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
     except TypeError:
         # If a TypeError is raised, the string is not a valid dtype
         return False
+
+
+def is_launched_with_accelerate() -> bool:
+    return "ACCELERATE_MIXED_PRECISION" in os.environ
+
+
+def get_accelerate_config(accelerator: Callable = None) -> dict[str, Any]:
+    config = {}
+    if not accelerator:
+        return config
+    config["num_processes"] = accelerator.num_processes
+    config["device"] = str(accelerator.device)
+    config["distributed_type"] = str(accelerator.distributed_type)
+    config["mixed_precision"] = accelerator.mixed_precision
+    config["gradient_accumulation_steps"] = accelerator.gradient_accumulation_steps
+
+    return config
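
Note: `get_accelerate_config` flattens the Accelerator's runtime settings into plain, serializable values so they can be attached to experiment logs or checkpoints. An illustrative call, assuming a hypothetical two-GPU fp16 launch (actual values and string representations depend on the launch flags and accelerate version):

    config = get_accelerate_config(accelerator)
    # e.g. {"num_processes": 2, "device": "cuda:0",
    #       "distributed_type": "DistributedType.MULTI_GPU",
    #       "mixed_precision": "fp16", "gradient_accumulation_steps": 1}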

lerobot/scripts/train.py

Lines changed: 69 additions & 31 deletions

@@ -17,7 +17,7 @@
 import time
 from contextlib import nullcontext
 from pprint import pformat
-from typing import Any
+from typing import Any, Callable
 
 import torch
 from termcolor import colored
@@ -46,6 +46,8 @@
     get_safe_torch_device,
     has_method,
     init_logging,
+    get_accelerate_config,
+    is_launched_with_accelerate,
 )
 from lerobot.common.utils.wandb_utils import WandBLogger
 from lerobot.configs import parser
@@ -63,40 +65,55 @@ def update_policy(
     lr_scheduler=None,
     use_amp: bool = False,
     lock=None,
+    accelerator: Callable = None,
 ) -> tuple[MetricsTracker, dict]:
     start_time = time.perf_counter()
     device = get_device_from_parameters(policy)
     policy.train()
-    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
+    with torch.autocast(device_type=device.type) if use_amp and accelerator is None else nullcontext():
         loss, output_dict = policy.forward(batch)
         # TODO(rcadene): policy.unnormalize_outputs(out_dict)
-    grad_scaler.scale(loss).backward()
 
-    # Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
-    grad_scaler.unscale_(optimizer)
-
-    grad_norm = torch.nn.utils.clip_grad_norm_(
-        policy.parameters(),
-        grad_clip_norm,
-        error_if_nonfinite=False,
-    )
+    if accelerator:
+        accelerator.backward(loss)
+        accelerator.unscale_gradients(optimizer=optimizer)
+        grad_norm = torch.nn.utils.clip_grad_norm_(
+            policy.parameters(),
+            grad_clip_norm,
+            error_if_nonfinite=False,
+        )
+        optimizer.step()
+    else:
+        grad_scaler.scale(loss).backward()
+        # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
+        grad_scaler.unscale_(optimizer)
+
+        grad_norm = torch.nn.utils.clip_grad_norm_(
+            policy.parameters(),
+            grad_clip_norm,
+            error_if_nonfinite=False,
+        )
 
-    # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
-    # although it still skips optimizer.step() if the gradients contain infs or NaNs.
-    with lock if lock is not None else nullcontext():
-        grad_scaler.step(optimizer)
-    # Updates the scale for next iteration.
-    grad_scaler.update()
+        # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
+        # although it still skips optimizer.step() if the gradients contain infs or NaNs.
+        with lock if lock is not None else nullcontext():
+            grad_scaler.step(optimizer)
+        # Updates the scale for next iteration.
+        grad_scaler.update()
 
     optimizer.zero_grad()
 
     # Step through pytorch scheduler at every batch instead of epoch
     if lr_scheduler is not None:
         lr_scheduler.step()
 
-    if has_method(policy, "update"):
-        # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
-        policy.update()
+    if accelerator:
+        # FIXME(mshukor): avoid accelerator.unwrap_model ?
+        if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
+            accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
+    else:
+        if has_method(policy, "update"):
+            # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
+            policy.update()
 
     train_metrics.loss = loss.item()
     train_metrics.grad_norm = grad_norm.item()
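
Why the two branches are equivalent: `Accelerator.backward` takes over the manual `GradScaler`'s job. On an fp16 launch, accelerate scales the loss internally before the backward pass, `unscale_gradients` restores true gradient magnitudes before clipping, and plain `optimizer.step()` is safe because accelerate wraps the optimizer, so step-skipping on inf/NaN gradients still happens inside the wrapper. A minimal, self-contained sketch of the two paths (names mirror `update_policy` above; assumes fp16 when scaling applies):

    import torch

    def backward_and_clip(loss, policy, optimizer, grad_clip_norm, accelerator=None, grad_scaler=None):
        if accelerator:
            accelerator.backward(loss)  # scales the loss internally when fp16 is active
            accelerator.unscale_gradients(optimizer=optimizer)  # unscale before clipping
            grad_norm = torch.nn.utils.clip_grad_norm_(policy.parameters(), grad_clip_norm)
            optimizer.step()  # accelerate's wrapped optimizer still skips on inf/NaN
        else:
            grad_scaler.scale(loss).backward()
            grad_scaler.unscale_(optimizer)  # unscale before clipping
            grad_norm = torch.nn.utils.clip_grad_norm_(policy.parameters(), grad_clip_norm)
            grad_scaler.step(optimizer)  # skips the step on inf/NaN gradients
            grad_scaler.update()
        return grad_norm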
@@ -106,21 +123,25 @@
 
 
 @parser.wrap()
-def train(cfg: TrainPipelineConfig):
+def train(cfg: TrainPipelineConfig, accelerator: Callable = None):
     cfg.validate()
     logging.info(pformat(cfg.to_dict()))
 
+    if accelerator and not accelerator.is_main_process:
+        # Disable wandb logging on non-main processes.
+        cfg.wandb.enable = False
+
     if cfg.wandb.enable and cfg.wandb.project:
         wandb_logger = WandBLogger(cfg)
     else:
         wandb_logger = None
         logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
 
     if cfg.seed is not None:
-        set_seed(cfg.seed)
+        set_seed(cfg.seed, accelerator=accelerator)
 
     # Check device is available
-    device = get_safe_torch_device(cfg.device, log=True)
+    device = get_safe_torch_device(cfg.device, log=True, accelerator=accelerator)
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
 
@@ -141,7 +162,7 @@ def train(cfg: TrainPipelineConfig):
         device=device,
         ds_meta=dataset.meta,
     )
-
+    policy.to(device)
     logging.info("Creating optimizer and scheduler")
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
     grad_scaler = GradScaler(device, enabled=cfg.use_amp)
@@ -184,6 +205,10 @@ def train(cfg: TrainPipelineConfig):
         pin_memory=device.type != "cpu",
         drop_last=False,
     )
+    if accelerator:
+        policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
+            policy, optimizer, dataloader, lr_scheduler
+        )
     dl_iter = cycle(dataloader)
 
     policy.train()
@@ -197,7 +222,7 @@ def train(cfg: TrainPipelineConfig):
     }
 
     train_tracker = MetricsTracker(
-        cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
+        cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step, accelerator=accelerator
     )
 
     logging.info("Start offline training on a fixed dataset")
@@ -219,6 +244,7 @@ def train(cfg: TrainPipelineConfig):
             grad_scaler=grad_scaler,
             lr_scheduler=lr_scheduler,
             use_amp=cfg.use_amp,
+            accelerator=accelerator,
         )
 
         # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -238,21 +264,26 @@ def train(cfg: TrainPipelineConfig):
             wandb_logger.log_dict(wandb_log_dict, step)
             train_tracker.reset_averages()
 
-        if cfg.save_checkpoint and is_saving_step:
+        if cfg.save_checkpoint and is_saving_step and (not accelerator or accelerator.is_main_process):
            logging.info(f"Checkpoint policy after step {step}")
             checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
-            save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
+            save_checkpoint(checkpoint_dir, step, cfg, policy if not accelerator else accelerator.unwrap_model(policy), optimizer, lr_scheduler)
             update_last_checkpoint(checkpoint_dir)
             if wandb_logger:
                 wandb_logger.log_policy(checkpoint_dir)
 
+        if accelerator:
+            accelerator.wait_for_everyone()
         if cfg.env and is_eval_step:
             step_id = get_step_identifier(step, cfg.steps)
             logging.info(f"Eval policy at step {step}")
-            with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+            with (
+                torch.no_grad(),
+                torch.autocast(device_type=device.type) if cfg.use_amp and not accelerator else nullcontext(),
+            ):
                 eval_info = eval_policy(
                     eval_env,
-                    policy,
+                    policy if not accelerator else accelerator.unwrap_model(policy),
                     cfg.eval.n_episodes,
                     videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
                     max_episodes_rendered=4,
@@ -265,7 +296,7 @@ def train(cfg: TrainPipelineConfig):
                 "eval_s": AverageMeter("eval_s", ":.3f"),
             }
             eval_tracker = MetricsTracker(
-                cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
+                cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step, accelerator=None
             )
             eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
             eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
@@ -283,4 +314,11 @@ def train(cfg: TrainPipelineConfig):
 
 if __name__ == "__main__":
     init_logging()
-    train()
+    if is_launched_with_accelerate():
+        import accelerate
+        # We set step_scheduler_with_optimizer to False to prevent accelerate from
+        # adjusting the lr_scheduler steps based on num_processes.
+        accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False)
+        train(accelerator=accelerator)
+    else:
+        train()
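
For reference, `is_launched_with_accelerate` keys off an environment variable that the `accelerate launch` CLI exports, so the same script serves both entry points. Illustrative invocations (flags depend on your hardware and accelerate config; the usual train.py arguments follow as normal):

    # distributed / mixed-precision path, handled by the new Accelerator branch:
    accelerate launch --num_processes=2 --mixed_precision=fp16 lerobot/scripts/train.py ...

    # unchanged single-process path:
    python lerobot/scripts/train.py ...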
