
Commit 77a32a7

refactor tests

1 parent 93bd8ee commit 77a32a7

1 file changed: +67 -58 lines changed

tests/models/test_modeling_common.py

Lines changed: 67 additions & 58 deletions
@@ -18,6 +18,7 @@
 import inspect
 import json
 import os
+import re
 import tempfile
 import traceback
 import unittest
@@ -183,6 +184,16 @@ def compute_module_persistent_sizes(
     return module_sizes


+def cast_maybe_tensor_dtype(maybe_tensor, current_dtype, target_dtype):
+    if torch.is_tensor(maybe_tensor):
+        return maybe_tensor.to(target_dtype) if maybe_tensor.dtype == current_dtype else maybe_tensor
+    if isinstance(maybe_tensor, dict):
+        return {k: cast_maybe_tensor_dtype(v, current_dtype, target_dtype) for k, v in maybe_tensor.items()}
+    if isinstance(maybe_tensor, list):
+        return [cast_maybe_tensor_dtype(v, current_dtype, target_dtype) for v in maybe_tensor]
+    return maybe_tensor
+
+
 class ModelUtilsTest(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
@@ -1334,80 +1345,78 @@ def test_variant_sharded_ckpt_right_format(self):
         assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files)

     def test_layerwise_upcasting_inference(self):
-        torch.manual_seed(0)
-        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**config).eval()
-        model = model.to(torch_device)
-        base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy()
+        from diffusers.hooks.layerwise_upcasting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS

-        # fp16-fp32
         torch.manual_seed(0)
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
         model = model.to(torch_device)
-        model.enable_layerwise_upcasting(storage_dtype=torch.float16, compute_dtype=torch.float32)
-        layerwise_upcast_slice_fp16 = model(**inputs_dict)[0].flatten().detach().cpu().numpy()
+        base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy()

-        # The precision test is not very important for fast tests. In most cases, the outputs will not be the same.
-        # We just want to make sure that the layerwise upcasting is working as expected.
-        self.assertTrue(numpy_cosine_similarity_distance(base_slice, layerwise_upcast_slice_fp16) < 1.0)
+        def check_linear_dtype(module, storage_dtype, compute_dtype):
+            patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
+            if getattr(module, "_always_upcast_modules", None) is not None:
+                patterns_to_check += tuple(module._always_upcast_modules)
+            for name, submodule in module.named_modules():
+                if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
+                    continue
+                dtype_to_check = storage_dtype
+                if any(re.search(pattern, name) for pattern in patterns_to_check):
+                    dtype_to_check = compute_dtype
+                if getattr(submodule, "weight", None) is not None:
+                    self.assertEqual(submodule.weight.dtype, dtype_to_check)
+                if getattr(submodule, "bias", None) is not None:
+                    self.assertEqual(submodule.bias.dtype, dtype_to_check)
+
+        def test_layerwise_upcasting(storage_dtype, compute_dtype):
+            torch.manual_seed(0)
+            config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
+            model = self.model_class(**config).eval()
+            model = model.to(torch_device, dtype=compute_dtype)
+            model.enable_layerwise_upcasting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
+            check_linear_dtype(model, storage_dtype, compute_dtype)
+            output = model(**inputs_dict)[0].float().flatten().detach().cpu().numpy()

-        # fp8_e4m3-fp32
-        torch.manual_seed(0)
-        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**config).eval()
-        model = model.to(torch_device)
-        model.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)
-        layerwise_upcast_slice_fp8_e4m3 = model(**inputs_dict)[0].flatten().detach().cpu().numpy()
+            # The precision test is not very important for fast tests. In most cases, the outputs will not be the same.
+            # We just want to make sure that the layerwise upcasting is working as expected.
+            self.assertTrue(numpy_cosine_similarity_distance(base_slice, output) < 1.0)

-        self.assertTrue(numpy_cosine_similarity_distance(base_slice, layerwise_upcast_slice_fp8_e4m3) < 1.0)
-
-        # fp8_e5m2-fp32
-        torch.manual_seed(0)
-        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**config).eval()
-        model = model.to(torch_device)
-        model.enable_layerwise_upcasting(storage_dtype=torch.float8_e5m2, compute_dtype=torch.float32)
-        layerwise_upcast_slice_fp8_e5m2 = model(**inputs_dict)[0].flatten().detach().cpu().numpy()
-
-        self.assertTrue(numpy_cosine_similarity_distance(base_slice, layerwise_upcast_slice_fp8_e5m2) < 1.0)
+        test_layerwise_upcasting(torch.float16, torch.float32)
+        test_layerwise_upcasting(torch.float8_e4m3fn, torch.float32)
+        test_layerwise_upcasting(torch.float8_e5m2, torch.float32)
+        test_layerwise_upcasting(torch.float8_e4m3fn, torch.bfloat16)

     @require_torch_gpu
     def test_layerwise_upcasting_memory(self):
-        # fp32
-        gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.synchronize()
+        def reset_memory_stats():
+            gc.collect()
+            torch.cuda.synchronize()
+            torch.cuda.empty_cache()
+            torch.cuda.reset_peak_memory_stats()

-        torch.manual_seed(0)
-        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**config).eval()
-        model = model.to(torch_device)
-        model(**inputs_dict)
-        base_memory_footprint = model.get_memory_footprint()
-        base_max_memory = torch.cuda.max_memory_allocated()
+        def get_memory_usage(storage_dtype, compute_dtype):
+            torch.manual_seed(0)
+            config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
+            model = self.model_class(**config).eval()
+            model = model.to(torch_device, dtype=compute_dtype)
+            model.enable_layerwise_upcasting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)

-        model.to("cpu")
-        del model
+            reset_memory_stats()
+            model(**inputs_dict)
+            model_memory_footprint = model.get_memory_footprint()
+            peak_inference_memory_allocated = torch.cuda.max_memory_allocated()

-        # fp8_e4m3-fp32
-        gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.synchronize()
+            return model_memory_footprint, peak_inference_memory_allocated

-        torch.manual_seed(0)
-        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**config).eval()
-        model = model.to(torch_device)
-        model.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)
-        model(**inputs_dict)
-        fp8_e4m3_memory_footprint = model.get_memory_footprint()
-        fp8_e4m3_max_memory = torch.cuda.max_memory_allocated()
+        fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32)
+        fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage(
+            torch.float8_e4m3fn, torch.bfloat16
+        )

-        self.assertTrue(fp8_e4m3_memory_footprint < base_memory_footprint)
-        self.assertTrue(fp8_e4m3_max_memory < base_max_memory)
+        self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint)
+        self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory)


 @is_staging_test
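
For readers skimming the diff: the new `cast_maybe_tensor_dtype` helper has no dependency on the test class, so its behavior is easy to verify in isolation. The snippet below is a minimal, hypothetical usage sketch (the input dict and its keys are invented for illustration); it shows that only tensors whose dtype matches `current_dtype` are cast to `target_dtype`, while dicts and lists are traversed recursively and everything else passes through unchanged.

import torch

def cast_maybe_tensor_dtype(maybe_tensor, current_dtype, target_dtype):
    # Cast a tensor only if it currently holds `current_dtype`; recurse into containers.
    if torch.is_tensor(maybe_tensor):
        return maybe_tensor.to(target_dtype) if maybe_tensor.dtype == current_dtype else maybe_tensor
    if isinstance(maybe_tensor, dict):
        return {k: cast_maybe_tensor_dtype(v, current_dtype, target_dtype) for k, v in maybe_tensor.items()}
    if isinstance(maybe_tensor, list):
        return [cast_maybe_tensor_dtype(v, current_dtype, target_dtype) for v in maybe_tensor]
    return maybe_tensor

# Invented example inputs: a float32 tensor, an integer tensor, and a nested list.
inputs = {
    "hidden_states": torch.randn(1, 4, 8),             # float32 -> cast
    "timestep": torch.tensor([1], dtype=torch.int64),  # not float32 -> left untouched
    "extras": [torch.randn(2), "not a tensor"],        # lists recurse; non-tensors pass through
}
casted = cast_maybe_tensor_dtype(inputs, torch.float32, torch.bfloat16)
assert casted["hidden_states"].dtype == torch.bfloat16
assert casted["timestep"].dtype == torch.int64
assert casted["extras"][0].dtype == torch.bfloat16

This is what lets the refactored tests run a model entirely in `bfloat16` (or any other compute dtype): the shared input-preparation helper always builds float32 inputs, and the test casts them to the compute dtype before calling the model.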
