@@ -380,26 +380,31 @@ def _eval_models(
380
380
m2 : DistributedModelParallel ,
381
381
batch : ModelInput ,
382
382
is_deterministic : bool = True ,
383
+ tolerance : Optional [float ] = None ,
383
384
) -> None :
384
385
with torch .no_grad ():
385
386
loss1 , pred1 = m1 (batch )
386
387
loss2 , pred2 = m2 (batch )
387
-
388
388
if is_deterministic :
389
389
self .assertTrue (torch .equal (loss1 , loss2 ))
390
390
self .assertTrue (torch .equal (pred1 , pred2 ))
391
391
else :
392
- rtol , atol = _get_default_rtol_and_atol (loss1 , loss2 )
393
- torch .testing .assert_close (loss1 , loss2 , rtol = rtol , atol = atol )
394
- rtol , atol = _get_default_rtol_and_atol (pred1 , pred2 )
395
- torch .testing .assert_close (pred1 , pred2 , rtol = rtol , atol = atol )
392
+ if tolerance :
393
+ torch .testing .assert_close (loss1 , loss2 , rtol = tolerance , atol = tolerance )
394
+ torch .testing .assert_close (pred1 , pred2 , rtol = tolerance , atol = tolerance )
395
+ else :
396
+ rtol , atol = _get_default_rtol_and_atol (loss1 , loss2 )
397
+ torch .testing .assert_close (loss1 , loss2 , rtol = rtol , atol = atol )
398
+ rtol , atol = _get_default_rtol_and_atol (pred1 , pred2 )
399
+ torch .testing .assert_close (pred1 , pred2 , rtol = rtol , atol = atol )
396
400
397
401
def _compare_models (
398
402
self ,
399
403
m1 : DistributedModelParallel ,
400
404
m2 : DistributedModelParallel ,
401
405
is_deterministic : bool = True ,
402
406
use_virtual_table : bool = False ,
407
+ tolerance : Optional [float ] = None ,
403
408
) -> None :
404
409
sd1 = m1 .state_dict ()
405
410
sd2 = m2 .state_dict ()
@@ -437,7 +442,12 @@ def _compare_models(
437
442
if is_deterministic :
438
443
self .assertTrue (torch .allclose (src_tensor , dst_tensor ))
439
444
else :
440
- rtol , atol = _get_default_rtol_and_atol (src_tensor , dst_tensor )
445
+ if tolerance :
446
+ rtol , atol = tolerance , tolerance
447
+ else :
448
+ rtol , atol = _get_default_rtol_and_atol (
449
+ src_tensor , dst_tensor
450
+ )
441
451
torch .testing .assert_close (
442
452
src_tensor , dst_tensor , rtol = rtol , atol = atol
443
453
)
@@ -453,7 +463,10 @@ def _compare_models(
453
463
if is_deterministic :
454
464
self .assertTrue (torch .equal (src , dst ))
455
465
else :
456
- rtol , atol = _get_default_rtol_and_atol (src , dst )
466
+ if tolerance :
467
+ rtol , atol = tolerance , tolerance
468
+ else :
469
+ rtol , atol = _get_default_rtol_and_atol (src , dst )
457
470
torch .testing .assert_close (
458
471
src ._local_tensor , dst ._local_tensor , rtol = rtol , atol = atol
459
472
)
@@ -463,7 +476,10 @@ def _compare_models(
463
476
if is_deterministic :
464
477
self .assertTrue (torch .equal (src , dst ))
465
478
else :
466
- rtol , atol = _get_default_rtol_and_atol (src , dst )
479
+ if tolerance :
480
+ rtol , atol = tolerance , tolerance
481
+ else :
482
+ rtol , atol = _get_default_rtol_and_atol (src , dst )
467
483
torch .testing .assert_close (src , dst , rtol = rtol , atol = atol )
468
484
469
485
0 commit comments