
Commit 54ecc98

mori360 authored and can-gaa-hou committed
[FSDP] Use post_reduce_stream.record_event() on hsdp+cpuoffload (pytorch#160481)
Fixes pytorch#160291

`post_reduce_stream` is `all_reduce_stream` during HSDP, but the CPU-GPU sync was hard-coded to `reduce_scatter_stream`. This hard-coding could fail the unit tests under HSDP + CPU offload, so a unit test for that case is added here.

Pull Request resolved: pytorch#160481
Approved by: https://github.com/weifengpy
1 parent bc2615d commit 54ecc98
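
As context for the fix below: during HSDP, the reduce-scatter is followed by an all-reduce, and the post-reduce work (including the non-blocking D2H copy of the reduced sharded gradient for CPU offload) runs on the all-reduce stream, so the event that the CPU thread later blocks on must be recorded on that stream. The following is a minimal standalone sketch of that stream/event interaction, not the FSDP source; the stream and variable names mirror the commit, while the tensors and the single-tensor copy are purely illustrative, and a CUDA device is assumed.

import torch

# Hypothetical standalone sketch: two side streams, as in FSDP's collectives.
reduce_scatter_stream = torch.cuda.Stream()
all_reduce_stream = torch.cuda.Stream()

# Under HSDP, the post-reduce work happens on the all-reduce stream.
post_reduce_stream = all_reduce_stream

grad_gpu = torch.randn(1024, device="cuda")   # reduced sharded grad (illustrative)
grad_cpu = torch.empty(1024, pin_memory=True) # pinned CPU buffer for offload

with torch.cuda.stream(post_reduce_stream):
    # Non-blocking device-to-host copy of the gradient for CPU offload.
    grad_cpu.copy_(grad_gpu, non_blocking=True)
    # Record the event on the stream that issued the copy. Recording it on
    # reduce_scatter_stream would not order the copy correctly under HSDP,
    # which is the bug this commit fixes.
    grad_offload_event = post_reduce_stream.record_event()

# Before the optimizer reads grad_cpu on the CPU thread, wait for the copy.
grad_offload_event.synchronize()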

File tree: 2 files changed, +23 −12 lines

test/distributed/_composable/fsdp/test_fully_shard_training.py
Lines changed: 22 additions & 11 deletions

@@ -335,7 +335,7 @@ def test_train_parity_multi_group(self):
         self.run_subtests(
             {
                 "reshard_after_forward": [True, False, 2],
-                "device_type": [device_type.type],
+                "test_device_type": [device_type.type],
                 "offload_policy": [OffloadPolicy()],
                 "delay_after_forward": [False, True],
                 "delay_before_all_gather": [False, True],
@@ -360,7 +360,7 @@ def test_train_parity_multi_group_cpu_offload_eager(self):
                     CPUOffloadPolicy(pin_memory=True),
                     CPUOffloadPolicy(pin_memory=False),
                 ],
-                "device_type": [device_type.type],
+                "test_device_type": [device_type.type],
                 "delay_after_forward": [False, True],
                 "delay_before_all_gather": [False, True],
                 "delay_before_reduce_scatter": [False, True],
@@ -381,7 +381,7 @@ def test_train_parity_multi_group_unshard_async_op(self):
         self.run_subtests(
             {
                 "reshard_after_forward": [True],
-                "device_type": [device_type.type],
+                "test_device_type": [device_type.type],
                 "offload_policy": [OffloadPolicy()],
                 "delay_after_forward": [False, True],
                 "delay_before_all_gather": [False, True],
@@ -396,7 +396,7 @@ def _test_train_parity_multi_group(
         self,
         reshard_after_forward: Union[bool, int],
         offload_policy: OffloadPolicy,
-        device_type: str,
+        test_device_type: str,
         delay_after_forward: bool,
         delay_before_all_gather: bool,
         delay_before_reduce_scatter: bool,
@@ -412,7 +412,7 @@ def _test_train_parity_multi_group(
             in (2, 3)
         ):
             return
-        assert device_type in ("cuda", "hpu", "xpu", "cpu"), f"{device_type}"
+        assert test_device_type in ("cuda", "hpu", "xpu", "cpu"), f"{test_device_type}"
        torch.manual_seed(42)
        vocab_size = 1024
        model_args = ModelArgs(
@@ -424,7 +424,7 @@ def _test_train_parity_multi_group(
         )
         model = Transformer(model_args)
         ref_model = copy.deepcopy(model)
-        if device_type == device_type:
+        if test_device_type == device_type.type:
             replicate(
                 ref_model.to(device_type),
                 device_ids=[self.rank],
@@ -433,7 +433,7 @@ def _test_train_parity_multi_group(
             gloo_pg = dist.new_group(backend="gloo")
             replicate(ref_model, process_group=gloo_pg)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
-        mesh = init_device_mesh(device_type, (self.world_size,))
+        mesh = init_device_mesh(test_device_type, (self.world_size,))
         fully_shard_fn = functools.partial(
             fully_shard,
             mesh=mesh,
@@ -483,12 +483,12 @@ def delayed_reduce_scatter(*args, **kwargs):
                 _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                 losses.append(_model(inp).sum())
                 if _model is model and delay_after_forward:
-                    torch.get_device_module(device_type)._sleep(
+                    torch.get_device_module(test_device_type)._sleep(
                         int(delay_in_ms * get_cycles_per_ms())
                     )
                 losses[-1].backward()
                 if _model is model and delay_before_optim:
-                    torch.get_device_module(device_type)._sleep(
+                    torch.get_device_module(test_device_type)._sleep(
                         int(delay_in_ms * get_cycles_per_ms())
                     )
                 _optim.step()
@@ -1360,6 +1360,10 @@ def test_train_parity_hsdp(self):
                 "use_activation_checkpointing": [False, True],
                 "mlp_dim": [3, 16, 17],
                 "sync_gradients_at_last_batch": [True, False],
+                "offload_policy": [
+                    CPUOffloadPolicy(pin_memory=True),
+                    CPUOffloadPolicy(pin_memory=False),
+                ],
             },
             functools.partial(self._test_train_parity_hsdp, global_mesh),
         )
@@ -1371,6 +1375,7 @@ def _test_train_parity_hsdp(
         use_activation_checkpointing: bool,
         mlp_dim: int,
         sync_gradients_at_last_batch: bool,
+        offload_policy: CPUOffloadPolicy,
     ):
         torch.manual_seed(42)
         model = nn.Sequential(
@@ -1389,10 +1394,16 @@ def _test_train_parity_hsdp(
             if use_activation_checkpointing:
                 checkpoint(mlp)
             fully_shard(
-                mlp, mesh=global_mesh, reshard_after_forward=reshard_after_forward
+                mlp,
+                mesh=global_mesh,
+                reshard_after_forward=reshard_after_forward,
+                offload_policy=offload_policy,
             )
         fully_shard(
-            model, mesh=global_mesh, reshard_after_forward=reshard_after_forward
+            model,
+            mesh=global_mesh,
+            reshard_after_forward=reshard_after_forward,
+            offload_policy=offload_policy,
         )
         optim = torch.optim.Adam(model.parameters(), lr=1e-2)
         check_sharded_parity(self, ref_model, model)

torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py
Lines changed: 1 addition & 1 deletion

@@ -628,7 +628,7 @@ def foreach_reduce(
                 if non_blocking:
                     # Record an event on which to block the CPU thread to
                     # ensure that the D2H copy finishes before the optimizer
-                    fsdp_param.grad_offload_event = reduce_scatter_stream.record_event()
+                    fsdp_param.grad_offload_event = post_reduce_stream.record_event()
             if to_accumulate_grad:
                 assert isinstance(fsdp_param.sharded_param.grad, DTensor)
                 fsdp_param.sharded_param.grad._local_tensor += new_sharded_grad
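
For readers unfamiliar with the stream naming in foreach_reduce: per the commit message, `post_reduce_stream` is the reduce-scatter stream for plain FSDP but the all-reduce stream for HSDP. The following is a simplified, hypothetical sketch of that selection based only on the description above, not the exact source.

# Simplified sketch of the stream choice described in the commit message: HSDP
# finishes gradient reduction (and the subsequent D2H copy for CPU offload)
# on the all-reduce stream, so offload events must be recorded there.
def pick_post_reduce_stream(reduce_scatter_stream, all_reduce_stream, is_hsdp):
    return all_reduce_stream if is_hsdp else reduce_scatter_stream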
