Skip to content

Commit 2e29099

Browse files
kausv authored and facebook-github-bot committed
Remove flakiness of test_kv_zch_load_state_dict (#3292)
Summary: Pull Request resolved: #3292 Test link: https://www.internalfb.com/intern/test/281475203207916 The test is flaky because KVZCH kernel [guarantees accuracy of 1e-2](https://www.internalfb.com/code/fbsource/[35a43c0e43e5]/fbcode/deeplearning/fbgemm/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py?lines=1399-1402) for FP16. I changed test_model_parallel_base to accept custom tolerance to override default atol/rtol and added the tolerance to this test to resolve the flakiness Reviewed By: duduyi2013 Differential Revision: D80457783 fbshipit-source-id: 07720dfceb5a2d393bff2fa2e4e0b0f81c7cac6e
1 parent 0f323fb commit 2e29099

File tree

2 files changed

+33
-11
lines changed

2 files changed

+33
-11
lines changed

torchrec/distributed/test_utils/test_model_parallel_base.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -380,26 +380,31 @@ def _eval_models(
380380
m2: DistributedModelParallel,
381381
batch: ModelInput,
382382
is_deterministic: bool = True,
383+
tolerance: Optional[float] = None,
383384
) -> None:
384385
with torch.no_grad():
385386
loss1, pred1 = m1(batch)
386387
loss2, pred2 = m2(batch)
387-
388388
if is_deterministic:
389389
self.assertTrue(torch.equal(loss1, loss2))
390390
self.assertTrue(torch.equal(pred1, pred2))
391391
else:
392-
rtol, atol = _get_default_rtol_and_atol(loss1, loss2)
393-
torch.testing.assert_close(loss1, loss2, rtol=rtol, atol=atol)
394-
rtol, atol = _get_default_rtol_and_atol(pred1, pred2)
395-
torch.testing.assert_close(pred1, pred2, rtol=rtol, atol=atol)
392+
if tolerance:
393+
torch.testing.assert_close(loss1, loss2, rtol=tolerance, atol=tolerance)
394+
torch.testing.assert_close(pred1, pred2, rtol=tolerance, atol=tolerance)
395+
else:
396+
rtol, atol = _get_default_rtol_and_atol(loss1, loss2)
397+
torch.testing.assert_close(loss1, loss2, rtol=rtol, atol=atol)
398+
rtol, atol = _get_default_rtol_and_atol(pred1, pred2)
399+
torch.testing.assert_close(pred1, pred2, rtol=rtol, atol=atol)
396400

397401
def _compare_models(
398402
self,
399403
m1: DistributedModelParallel,
400404
m2: DistributedModelParallel,
401405
is_deterministic: bool = True,
402406
use_virtual_table: bool = False,
407+
tolerance: Optional[float] = None,
403408
) -> None:
404409
sd1 = m1.state_dict()
405410
sd2 = m2.state_dict()
@@ -437,7 +442,12 @@ def _compare_models(
437442
if is_deterministic:
438443
self.assertTrue(torch.allclose(src_tensor, dst_tensor))
439444
else:
440-
rtol, atol = _get_default_rtol_and_atol(src_tensor, dst_tensor)
445+
if tolerance:
446+
rtol, atol = tolerance, tolerance
447+
else:
448+
rtol, atol = _get_default_rtol_and_atol(
449+
src_tensor, dst_tensor
450+
)
441451
torch.testing.assert_close(
442452
src_tensor, dst_tensor, rtol=rtol, atol=atol
443453
)
@@ -453,7 +463,10 @@ def _compare_models(
453463
if is_deterministic:
454464
self.assertTrue(torch.equal(src, dst))
455465
else:
456-
rtol, atol = _get_default_rtol_and_atol(src, dst)
466+
if tolerance:
467+
rtol, atol = tolerance, tolerance
468+
else:
469+
rtol, atol = _get_default_rtol_and_atol(src, dst)
457470
torch.testing.assert_close(
458471
src._local_tensor, dst._local_tensor, rtol=rtol, atol=atol
459472
)
@@ -463,7 +476,10 @@ def _compare_models(
463476
if is_deterministic:
464477
self.assertTrue(torch.equal(src, dst))
465478
else:
466-
rtol, atol = _get_default_rtol_and_atol(src, dst)
479+
if tolerance:
480+
rtol, atol = tolerance, tolerance
481+
else:
482+
rtol, atol = _get_default_rtol_and_atol(src, dst)
467483
torch.testing.assert_close(src, dst, rtol=rtol, atol=atol)
468484

469485

torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,7 +1022,7 @@ def test_kv_zch_load_state_dict(
10221022
"learning_rate": 0.1,
10231023
"stochastic_rounding": stochastic_rounding,
10241024
}
1025-
is_deterministic = dtype == DataType.FP32 or not stochastic_rounding
1025+
is_deterministic = dtype == DataType.FP32
10261026
constraints = {
10271027
table.name: ParameterConstraints(
10281028
sharding_types=[sharding_type],
@@ -1049,9 +1049,15 @@ def test_kv_zch_load_state_dict(
10491049

10501050
if is_training:
10511051
self._train_models(m1, m2, batch)
1052-
self._eval_models(m1, m2, batch, is_deterministic=is_deterministic)
1052+
self._eval_models(
1053+
m1, m2, batch, is_deterministic=is_deterministic, tolerance=1e-2
1054+
)
10531055
self._compare_models(
1054-
m1, m2, is_deterministic=is_deterministic, use_virtual_table=True
1056+
m1,
1057+
m2,
1058+
is_deterministic=is_deterministic,
1059+
use_virtual_table=True,
1060+
tolerance=1e-2,
10551061
)
10561062

10571063
@unittest.skipIf(

0 commit comments

Comments (0)