Merged. Changes shown are from 2 of the PR's commits.

examples/torch_native_parallelism/nd_parallel.py (6 additions, 2 deletions):

```diff
@@ -31,6 +31,7 @@
     PerformanceTracker,
     create_collate_fn,
     get_dataset,
+    get_model_flops_per_token,
     setup_tokenizer,
 )
@@ -73,7 +74,7 @@ def forward(model, batch, optimizer, accelerator: Accelerator):
     loss = outputs.loss
     accelerator.backward(loss)
     optimizer.step()
-    optimizer.zero_grad()
+    optimizer.zero_grad(set_to_none=False)
     dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)
 
     return loss
```
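The switch to `optimizer.zero_grad(set_to_none=False)` above changes how gradients are cleared: each `p.grad` is zeroed in place instead of being freed to `None` (the PyTorch default). The diff does not state the motivation; the sketch below simply illustrates the two modes in plain PyTorch, independent of any parallelism:

```python
import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.randn(2, 4)).sum().backward()
opt.zero_grad(set_to_none=False)
print(model.weight.grad is None)  # False: the grad tensor is kept and filled with zeros
opt.zero_grad()                   # default set_to_none=True
print(model.weight.grad is None)  # True: the grad is freed and re-allocated on the next backward
```

Later in the same file, `train()` wires a new FLOPs-per-token estimate into the performance tracker: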
```diff
@@ -123,6 +124,7 @@ def train(args):
 
     total_num_steps = min(args.num_steps, len(dataloader))
     performance_tracker = PerformanceTracker(warmup_steps=5)
+    model_flops_per_token = get_model_flops_per_token(model, args.sequence_length)
 
     accelerator.print("Starting training...")
     for step, batch in enumerate(dataloader):
@@ -132,7 +134,9 @@ def train(args):
         loss = forward(model, batch, optimizer, accelerator)
 
         # We report TPS per device, so we divide by the number of devices in the non-data parallel dimension
-        metrics = performance_tracker.step(batch["input_ids"].shape[1] / parallelism_config.non_data_parallel_size)
+        metrics = performance_tracker.step(
+            batch["input_ids"].shape[1] / parallelism_config.non_data_parallel_size, model_flops_per_token
+        )
 
         print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
         if "warmup_completed" in metrics:
```
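`get_model_flops_per_token` comes from the example's shared `utils` module and its body is not part of this diff. For orientation only, here is a hypothetical sketch of the standard approximation such helpers tend to use for dense decoder-only transformers (roughly 6 FLOPs per parameter per token for the matmuls, plus an attention term that grows with sequence length); the real implementation may differ:

```python
def approx_model_flops_per_token(num_params: int, num_layers: int, hidden_size: int, seq_len: int) -> float:
    # Hypothetical helper for illustration; not the repository's get_model_flops_per_token.
    dense = 6 * num_params  # forward + backward matmul cost, ~6 FLOPs per parameter per token
    attention = 12 * num_layers * hidden_size * seq_len  # attention score/context matmuls
    return dense + attention
```

The division by `parallelism_config.non_data_parallel_size` keeps the report per device: ranks that differ only in the non-data-parallel dimensions process the same tokens, so with `sequence_length=4096` and a non-data-parallel footprint of 8, each device is credited with 512 tokens per step. Presumably `PerformanceTracker.step` combines that token rate with `model_flops_per_token` to report achieved FLOPS; the tracker itself is also outside this diff.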
src/accelerate/accelerator.py (1 addition, 7 deletions):

```diff
@@ -1570,13 +1570,7 @@ def prepare(self, *args, device_placement=None):
         return result if len(result) > 1 else result[0]
 
     def _prepare_tp(self, *args):
-        # First pass: prepare everything except schedulers (and model, which is prepared separately below)
-        result = [
-            self._prepare_one(obj, first_pass=True) if not isinstance(obj, torch.nn.Module) else obj for obj in args
-        ]
-
-        # Second pass: prepare schedulers
-        result = [self._prepare_one(obj) if not isinstance(obj, torch.nn.Module) else obj for obj in result]
+        result = list(args)
 
         device_mesh = self.torch_device_mesh
 
```
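With both `_prepare_one` passes removed, `_prepare_tp` now forwards every argument untouched; per the deleted comment, the model was already prepared separately below, which is where the fetched `device_mesh` comes into play. For orientation, a hypothetical sketch of torch's native tensor-parallel API that this code path builds on; the mesh size and the plan's module names are placeholders, not accelerate's actual code:

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

# Run under torchrun with 8 processes; sizes and names are illustrative.
mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))
model = parallelize_module(
    model,
    mesh["tp"],
    {"0": ColwiseParallel(), "2": RowwiseParallel()},  # shard the first Linear by columns, the second by rows
)
```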