examples/torch_native_parallelism/nd_parallel.py — 8 changes: 6 additions & 2 deletions
@@ -31,6 +31,7 @@
     PerformanceTracker,
     create_collate_fn,
     get_dataset,
+    get_model_flops_per_token,
     setup_tokenizer,
 )

@@ -73,7 +74,7 @@ def forward(model, batch, optimizer, accelerator: Accelerator):
     loss = outputs.loss
     accelerator.backward(loss)
     optimizer.step()
-    optimizer.zero_grad()
+    optimizer.zero_grad(set_to_none=False)
     dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)

     return loss
@@ -123,6 +124,7 @@ def train(args):

     total_num_steps = min(args.num_steps, len(dataloader))
     performance_tracker = PerformanceTracker(warmup_steps=5)
+    model_flops_per_token = get_model_flops_per_token(model, args.sequence_length)

     accelerator.print("Starting training...")
     for step, batch in enumerate(dataloader):
@@ -132,7 +134,9 @@
         loss = forward(model, batch, optimizer, accelerator)

         # We report TPS per device, so we divide by the number of devices in the non-data parallel dimension
-        metrics = performance_tracker.step(batch["input_ids"].shape[1] / parallelism_config.non_data_parallel_size)
+        metrics = performance_tracker.step(
+            batch["input_ids"].shape[1] / parallelism_config.non_data_parallel_size, model_flops_per_token
+        )

         print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
         if "warmup_completed" in metrics:
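Note on the new helper: `get_model_flops_per_token` comes from the example's `utils` module and supplies the per-token FLOPs estimate that is now passed to `PerformanceTracker.step`, presumably so the tracker can report achieved FLOPS alongside tokens/sec. The sketch below shows one common way such an estimate is computed (the "6N plus attention" approximation for decoder-only transformers); the function name `estimate_flops_per_token` and its arguments are illustrative, not the helper's actual signature.

```python
# Hedged sketch, not the helper's actual implementation: roughly 6 FLOPs per
# parameter per token for forward + backward, plus an attention term that
# grows with sequence length (~12 * layers * hidden_size * seq_len).
def estimate_flops_per_token(num_params: int, num_layers: int, hidden_size: int, seq_len: int) -> float:
    dense_flops = 6 * num_params                                # matmul cost over all weights, fwd + bwd
    attention_flops = 12 * num_layers * hidden_size * seq_len   # attention scores and value mixing
    return dense_flops + attention_flops
```

Multiplying a per-token figure like this by the per-device tokens/sec the tracker already measures yields per-device FLOPS.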
src/accelerate/accelerator.py — 9 changes: 8 additions & 1 deletion
@@ -1623,7 +1623,14 @@ def _get_tensor_address(p):
                 # so that the optimizer can correctly update the model parameters.
                 param_group["params"] = [mapping[_get_tensor_address(p)] for p in param_group["params"]]

-        return args
+        for item in result:
+            if any(
+                item in container
+                for container in (self._dataloaders, self._models, self._optimizers, self._schedulers)
+            ):
+                item._is_accelerate_prepared = True
+
+        return result

     def _prepare_cp(self, *args):
         from torch.distributed.tensor.experimental import context_parallel
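As context for the accelerator.py change: the new loop tags every tracked object that came out of preparation with `_is_accelerate_prepared = True`, presumably so later code can tell these objects have already been prepared and avoid preparing them twice. The toy class below only illustrates that pattern; it is not the library's real code path, and all names in it are illustrative.

```python
# Illustrative only: a toy accelerator showing why tagging prepared objects helps.
# A second prepare() call can detect the flag and return the object unchanged
# instead of wrapping it again.
class TinyAccelerator:
    def prepare(self, *objs):
        prepared = []
        for obj in objs:
            if getattr(obj, "_is_accelerate_prepared", False):
                prepared.append(obj)              # already prepared: skip re-wrapping
                continue
            wrapped = self._wrap(obj)             # stand-in for device placement / sharding / wrapping
            wrapped._is_accelerate_prepared = True
            prepared.append(wrapped)
        return prepared

    def _wrap(self, obj):
        return obj  # placeholder for the real preparation logic
```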