
Commit 376adf9

add assertion with fp32 comparison; add tolerance to fp8-fp32 vs fp32-fp32 comparison (required for a few models' tests to pass)
1 parent 7803364 commit 376adf9

File tree: 1 file changed, +13 / -3 lines


tests/models/test_modeling_common.py

Lines changed: 13 additions & 3 deletions
@@ -1392,6 +1392,8 @@ def test_layerwise_casting(storage_dtype, compute_dtype):
 
     @require_torch_gpu
     def test_layerwise_casting_memory(self):
+        MB_TOLERANCE = 0.2
+
         def reset_memory_stats():
             gc.collect()
             torch.cuda.synchronize()
@@ -1409,17 +1411,25 @@ def get_memory_usage(storage_dtype, compute_dtype):
             reset_memory_stats()
             model(**inputs_dict)
             model_memory_footprint = model.get_memory_footprint()
-            peak_inference_memory_allocated = torch.cuda.max_memory_allocated()
+            peak_inference_memory_allocated_mb = torch.cuda.max_memory_allocated() / 1024**2
 
-            return model_memory_footprint, peak_inference_memory_allocated
+            return model_memory_footprint, peak_inference_memory_allocated_mb
 
+        fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32)
         fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32)
         fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage(
            torch.float8_e4m3fn, torch.bfloat16
         )
 
-        self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint)
+        self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint)
         self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory)
+        # On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
+        # bytes. This only happens for some models, so we allow a small tolerance.
+        # For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
+        self.assertTrue(
+            fp8_e4m3_fp32_max_memory < fp32_max_memory
+            or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
+        )
 
 
 @is_staging_test
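As a minimal, standalone sketch of the tolerance check this commit introduces (the helper name peak_memory_ok and the numeric values are hypothetical, not part of the actual test), the new assertion passes if the fp8-fp32 peak memory is strictly below the fp32-fp32 peak, or if the two differ by less than MB_TOLERANCE megabytes:

MB_TOLERANCE = 0.2  # megabytes

def peak_memory_ok(fp8_fp32_max_memory_mb: float, fp32_max_memory_mb: float) -> bool:
    # Accept either a strict ordering, or a gap smaller than the tolerance
    # (tiny dummy test models can exceed the fp32 peak by a few bytes).
    return (
        fp8_fp32_max_memory_mb < fp32_max_memory_mb
        or abs(fp8_fp32_max_memory_mb - fp32_max_memory_mb) < MB_TOLERANCE
    )

assert peak_memory_ok(120.05, 120.0)      # slightly above fp32, but within tolerance: passes
assert not peak_memory_ok(121.0, 120.0)   # well above tolerance: fails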
