@@ -1392,6 +1392,8 @@ def test_layerwise_casting(storage_dtype, compute_dtype):
 
     @require_torch_gpu
     def test_layerwise_casting_memory(self):
+        MB_TOLERANCE = 0.2
+
         def reset_memory_stats():
             gc.collect()
             torch.cuda.synchronize()
@@ -1409,17 +1411,25 @@ def get_memory_usage(storage_dtype, compute_dtype):
             reset_memory_stats()
             model(**inputs_dict)
             model_memory_footprint = model.get_memory_footprint()
-            peak_inference_memory_allocated = torch.cuda.max_memory_allocated()
+            peak_inference_memory_allocated_mb = torch.cuda.max_memory_allocated() / 1024**2
 
-            return model_memory_footprint, peak_inference_memory_allocated
+            return model_memory_footprint, peak_inference_memory_allocated_mb
 
+        fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32)
         fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32)
         fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage(
             torch.float8_e4m3fn, torch.bfloat16
         )
 
-        self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint)
+        self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint)
         self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory)
+        # On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32
+        # by a few bytes. This only happens for some models, so we allow a small tolerance.
+        # For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
+        self.assertTrue(
+            fp8_e4m3_fp32_max_memory < fp32_max_memory
+            or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
+        )
 
 
 @is_staging_test
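For context, the measurement pattern these assertions rely on can be sketched in isolation. The snippet below is a minimal, self-contained illustration of measuring peak CUDA memory in MB and comparing two runs with a small tolerance; the nn.Linear dummy model, helper names, and dtypes are illustrative assumptions rather than the actual diffusers test, and it requires a CUDA device to run.

import gc

import torch


def reset_memory_stats():
    # Drop cached allocations and reset the peak-memory counter before each run.
    gc.collect()
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()


def peak_inference_memory_mb(model, inputs):
    # Run one forward pass and report peak allocated memory in MB,
    # matching the unit used by MB_TOLERANCE in the test above.
    reset_memory_stats()
    with torch.no_grad():
        model(inputs)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024**2


if __name__ == "__main__":
    MB_TOLERANCE = 0.2  # tiny workloads can differ by a few bytes either way

    # Hypothetical dummy model, standing in for the test suite's dummy models.
    model = torch.nn.Linear(256, 256, device="cuda")
    inputs = torch.randn(8, 256, device="cuda")

    fp32_peak = peak_inference_memory_mb(model.float(), inputs.float())
    bf16_peak = peak_inference_memory_mb(model.bfloat16(), inputs.bfloat16())

    # Lower-precision compute should not peak above fp32, modulo allocator noise.
    assert bf16_peak < fp32_peak or abs(bf16_peak - fp32_peak) < MB_TOLERANCE

The tolerance mirrors the comment in the diff: on very small dummy models, caching-allocator rounding can flip the expected ordering by a few bytes, so a strict inequality would make the test flaky even though real models follow the expected order.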