
Commit 4f08f96

Commit message: add comments for addressing feedback

1 parent 6505e36

File tree: 2 files changed (+7, -0 lines)


src/accelerate/accelerator.py (4 additions, 0 deletions)

@@ -1627,6 +1627,10 @@ def _get_tensor_address(p):
         for obj in result:
             if isinstance(obj, torch.optim.Optimizer):
                 for param_group in obj.param_groups:
+                    # Each param_group originally maps to model parameters (e.g., from model.parameters()).
+                    # After _prepare_tp(), parameter references are replaced with DTensor instances.
+                    # Therefore, we remap the parameter references to their new DTensor addresses
+                    # so that the optimizer can correctly update the model parameters.
                     param_group["params"] = [mapping[_get_tensor_address(p)] for p in param_group["params"]]
 
         return args
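
For context on why this remapping matters, here is a minimal standalone sketch. The toy Linear model, the SGD optimizer, and the data_ptr()-based _get_tensor_address are illustrative assumptions rather than Accelerate's actual implementation; the parameter swap below stands in for what _prepare_tp() does when it replaces a model's parameters with new tensor objects.

    import torch

    def _get_tensor_address(p):
        # Assumption for illustration: identify a tensor by its storage address.
        # The helper of the same name in accelerator.py may be implemented differently.
        return p.data_ptr()

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Simulate the parameter swap done by _prepare_tp(): build replacement
    # parameters and record a mapping from old addresses to the new objects.
    mapping = {}
    for name, old_param in list(model.named_parameters()):
        new_param = torch.nn.Parameter(old_param.detach().clone())
        mapping[_get_tensor_address(old_param)] = new_param
        setattr(model, name, new_param)

    # The optimizer still references the orphaned old parameters, so rewrite
    # each param_group to point at the replacements; otherwise optimizer.step()
    # would update tensors the model no longer uses.
    for param_group in optimizer.param_groups:
        param_group["params"] = [mapping[_get_tensor_address(p)] for p in param_group["params"]]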

src/accelerate/utils/fsdp_utils.py (3 additions, 0 deletions)

@@ -507,6 +507,9 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
     device_mesh = sharded_param.device_mesh
     full_param = full_param.detach().to(device_mesh.device_type)
     if isinstance(full_param, DTensor):
+        # dist.broadcast() only supports torch.Tensor.
+        # After prepare_tp(), model parameters may become DTensor.
+        # To broadcast such a parameter, convert it to a local tensor first.
         full_param = full_param.to_local()
     dist.broadcast(full_param, src=0, group=dist.group.WORLD)
     sharded_tensor = distribute_tensor(full_param, device_mesh, sharded_param.placements)
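
For reference, the broadcast-then-reshard pattern these comments describe can be written as a self-contained helper. This is a sketch under assumptions, not Accelerate's exact code: it presumes an already-initialized process group, that sharded_param is a DTensor with a populated device_mesh, and the public torch.distributed.tensor import path (older PyTorch releases expose the same names under torch.distributed._tensor).

    import torch.distributed as dist
    from torch.distributed.tensor import DTensor, distribute_tensor

    def broadcast_then_shard(full_param, sharded_param):
        # Hypothetical helper mirroring the lines above; requires
        # dist.init_process_group() to have been called beforehand.
        device_mesh = sharded_param.device_mesh
        full_param = full_param.detach().to(device_mesh.device_type)
        if isinstance(full_param, DTensor):
            # dist.broadcast() only accepts a plain torch.Tensor, so unwrap the
            # DTensor into its local tensor before the collective.
            full_param = full_param.to_local()
        dist.broadcast(full_param, src=0, group=dist.group.WORLD)
        # Re-distribute the synchronized tensor with the original placements.
        return distribute_tensor(full_param, device_mesh, sharded_param.placements)

Unwrapping with to_local() keeps the collective operating on an ordinary tensor, and distribute_tensor() then re-applies the original placements so the result matches the sharded layout expected downstream.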
