diff --git a/examples/torch_native_parallelism/nd_parallel.py b/examples/torch_native_parallelism/nd_parallel.py
index e48f6662091..0dd23c8e39a 100644
--- a/examples/torch_native_parallelism/nd_parallel.py
+++ b/examples/torch_native_parallelism/nd_parallel.py
@@ -31,6 +31,7 @@
     PerformanceTracker,
     create_collate_fn,
     get_dataset,
+    get_model_flops_per_token,
     setup_tokenizer,
 )
 
@@ -73,7 +74,7 @@ def forward(model, batch, optimizer, accelerator: Accelerator):
     loss = outputs.loss
     accelerator.backward(loss)
     optimizer.step()
-    optimizer.zero_grad()
+    optimizer.zero_grad(set_to_none=False)
     dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)
     return loss
 
@@ -123,6 +124,7 @@ def train(args):
     total_num_steps = min(args.num_steps, len(dataloader))
 
     performance_tracker = PerformanceTracker(warmup_steps=5)
+    model_flops_per_token = get_model_flops_per_token(model, args.sequence_length)
 
     accelerator.print("Starting training...")
     for step, batch in enumerate(dataloader):
@@ -132,7 +134,9 @@ def train(args):
         loss = forward(model, batch, optimizer, accelerator)
 
         # We report TPS per device, so we divide by the number of devices in the non-data parallel dimension
-        metrics = performance_tracker.step(batch["input_ids"].shape[1] / parallelism_config.non_data_parallel_size)
+        metrics = performance_tracker.step(
+            batch["input_ids"].shape[1] / parallelism_config.non_data_parallel_size, model_flops_per_token
+        )
 
         print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
         if "warmup_completed" in metrics:
diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 2f313ca508f..962f07bbfac 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -1528,6 +1528,12 @@ def prepare(self, *args, device_placement=None):
 
         if self.parallelism_config and self.parallelism_config.tp_enabled:
             args = self._prepare_tp(*args)
+            for item in args:
+                if any(
+                    item in container
+                    for container in (self._dataloaders, self._models, self._optimizers, self._schedulers)
+                ):
+                    item._is_accelerate_prepared = True
         if self.parallelism_config and self.parallelism_config.cp_enabled:
             args = self._prepare_cp(*args)
 
@@ -1623,7 +1629,7 @@ def _get_tensor_address(p):
             # so that the optimizer can correctly update the model parameters.
             param_group["params"] = [mapping[_get_tensor_address(p)] for p in param_group["params"]]
 
-        return args
+        return result
 
     def _prepare_cp(self, *args):
         from torch.distributed.tensor.experimental import context_parallel