@@ -335,7 +335,7 @@ def test_train_parity_multi_group(self):
         self.run_subtests(
             {
                 "reshard_after_forward": [True, False, 2],
-                "device_type": [device_type.type],
+                "test_device_type": [device_type.type],
                 "offload_policy": [OffloadPolicy()],
                 "delay_after_forward": [False, True],
                 "delay_before_all_gather": [False, True],
@@ -360,7 +360,7 @@ def test_train_parity_multi_group_cpu_offload_eager(self):
                     CPUOffloadPolicy(pin_memory=True),
                     CPUOffloadPolicy(pin_memory=False),
                 ],
-                "device_type": [device_type.type],
+                "test_device_type": [device_type.type],
                 "delay_after_forward": [False, True],
                 "delay_before_all_gather": [False, True],
                 "delay_before_reduce_scatter": [False, True],
@@ -381,7 +381,7 @@ def test_train_parity_multi_group_unshard_async_op(self):
         self.run_subtests(
             {
                 "reshard_after_forward": [True],
-                "device_type": [device_type.type],
+                "test_device_type": [device_type.type],
                 "offload_policy": [OffloadPolicy()],
                 "delay_after_forward": [False, True],
                 "delay_before_all_gather": [False, True],
@@ -396,7 +396,7 @@ def _test_train_parity_multi_group(
         self,
         reshard_after_forward: Union[bool, int],
         offload_policy: OffloadPolicy,
-        device_type: str,
+        test_device_type: str,
         delay_after_forward: bool,
         delay_before_all_gather: bool,
         delay_before_reduce_scatter: bool,
@@ -412,7 +412,7 @@ def _test_train_parity_multi_group(
             in (2, 3)
         ):
             return
-        assert device_type in ("cuda", "hpu", "xpu", "cpu"), f"{device_type}"
+        assert test_device_type in ("cuda", "hpu", "xpu", "cpu"), f"{test_device_type}"
         torch.manual_seed(42)
         vocab_size = 1024
         model_args = ModelArgs(
@@ -424,7 +424,7 @@ def _test_train_parity_multi_group(
         )
         model = Transformer(model_args)
         ref_model = copy.deepcopy(model)
-        if device_type == device_type:
+        if test_device_type == device_type.type:
             replicate(
                 ref_model.to(device_type),
                 device_ids=[self.rank],
@@ -433,7 +433,7 @@ def _test_train_parity_multi_group(
             gloo_pg = dist.new_group(backend="gloo")
             replicate(ref_model, process_group=gloo_pg)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
-        mesh = init_device_mesh(device_type, (self.world_size,))
+        mesh = init_device_mesh(test_device_type, (self.world_size,))
         fully_shard_fn = functools.partial(
             fully_shard,
             mesh=mesh,
@@ -483,12 +483,12 @@ def delayed_reduce_scatter(*args, **kwargs):
                 _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                 losses.append(_model(inp).sum())
                 if _model is model and delay_after_forward:
-                    torch.get_device_module(device_type)._sleep(
+                    torch.get_device_module(test_device_type)._sleep(
                         int(delay_in_ms * get_cycles_per_ms())
                     )
                 losses[-1].backward()
                 if _model is model and delay_before_optim:
-                    torch.get_device_module(device_type)._sleep(
+                    torch.get_device_module(test_device_type)._sleep(
                         int(delay_in_ms * get_cycles_per_ms())
                     )
                 _optim.step()
@@ -1360,6 +1360,10 @@ def test_train_parity_hsdp(self):
                 "use_activation_checkpointing": [False, True],
                 "mlp_dim": [3, 16, 17],
                 "sync_gradients_at_last_batch": [True, False],
+                "offload_policy": [
+                    CPUOffloadPolicy(pin_memory=True),
+                    CPUOffloadPolicy(pin_memory=False),
+                ],
             },
             functools.partial(self._test_train_parity_hsdp, global_mesh),
         )
@@ -1371,6 +1375,7 @@ def _test_train_parity_hsdp(
         use_activation_checkpointing: bool,
         mlp_dim: int,
         sync_gradients_at_last_batch: bool,
+        offload_policy: CPUOffloadPolicy,
     ):
         torch.manual_seed(42)
         model = nn.Sequential(
@@ -1389,10 +1394,16 @@ def _test_train_parity_hsdp(
             if use_activation_checkpointing:
                 checkpoint(mlp)
             fully_shard(
-                mlp, mesh=global_mesh, reshard_after_forward=reshard_after_forward
+                mlp,
+                mesh=global_mesh,
+                reshard_after_forward=reshard_after_forward,
+                offload_policy=offload_policy,
             )
         fully_shard(
-            model, mesh=global_mesh, reshard_after_forward=reshard_after_forward
+            model,
+            mesh=global_mesh,
+            reshard_after_forward=reshard_after_forward,
+            offload_policy=offload_policy,
         )
         optim = torch.optim.Adam(model.parameters(), lr=1e-2)
         check_sharded_parity(self, ref_model, model)