|
18 | 18 | import inspect |
19 | 19 | import json |
20 | 20 | import os |
| 21 | +import re |
21 | 22 | import tempfile |
22 | 23 | import traceback |
23 | 24 | import unittest |
@@ -183,6 +184,16 @@ def compute_module_persistent_sizes( |
183 | 184 | return module_sizes |
184 | 185 |
|
185 | 186 |
|
def cast_maybe_tensor_dtype(maybe_tensor, current_dtype, target_dtype):
    """Recursively cast tensors of ``current_dtype`` to ``target_dtype``.

    Args:
        maybe_tensor: A tensor, a (possibly nested) dict/list/tuple that may contain tensors, or any other object.
        current_dtype: Only tensors whose dtype equals this value are converted.
        target_dtype: The dtype matching tensors are converted to.

    Returns:
        The converted tensor, or a new container of the same kind with its values converted. Tensors with a
        different dtype and all non-tensor, non-container objects are returned unchanged.
    """
    if torch.is_tensor(maybe_tensor):
        # Leave tensors of other dtypes alone; only the requested dtype is converted.
        return maybe_tensor.to(target_dtype) if maybe_tensor.dtype == current_dtype else maybe_tensor
    if isinstance(maybe_tensor, dict):
        return {k: cast_maybe_tensor_dtype(v, current_dtype, target_dtype) for k, v in maybe_tensor.items()}
    if isinstance(maybe_tensor, (list, tuple)):
        converted = [cast_maybe_tensor_dtype(v, current_dtype, target_dtype) for v in maybe_tensor]
        if isinstance(maybe_tensor, list):
            return converted
        # namedtuples take their fields as positional arguments rather than a single iterable.
        if hasattr(maybe_tensor, "_fields"):
            return type(maybe_tensor)(*converted)
        return tuple(converted)
    return maybe_tensor
| 195 | + |
| 196 | + |
186 | 197 | class ModelUtilsTest(unittest.TestCase): |
187 | 198 | def tearDown(self): |
188 | 199 | super().tearDown() |
@@ -1334,80 +1345,78 @@ def test_variant_sharded_ckpt_right_format(self): |
1334 | 1345 | assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files) |
1335 | 1346 |
|
def test_layerwise_upcasting_inference(self):
    """Check that layerwise upcasting stores weights in the expected dtypes and still runs inference.

    A baseline output is computed once without upcasting. For each (storage_dtype, compute_dtype) pair the
    model is rebuilt, upcasting is enabled, the weight/bias dtypes of supported layers are verified, and the
    output is compared against the baseline using a deliberately loose cosine-distance bound.
    """
    from diffusers.hooks.layerwise_upcasting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS

    torch.manual_seed(0)
    config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    model = self.model_class(**config).eval()
    model = model.to(torch_device)
    base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy()

    def check_linear_dtype(module, storage_dtype, compute_dtype):
        # Copy into a fresh tuple so extending the patterns can never mutate the imported default.
        patterns_to_check = tuple(DEFAULT_SKIP_MODULES_PATTERN)
        if getattr(module, "_always_upcast_modules", None) is not None:
            patterns_to_check += tuple(module._always_upcast_modules)
        for name, submodule in module.named_modules():
            if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
                continue
            # Layers whose name matches a pattern are expected in the compute dtype; the rest in storage dtype.
            dtype_to_check = storage_dtype
            if any(re.search(pattern, name) for pattern in patterns_to_check):
                dtype_to_check = compute_dtype
            if getattr(submodule, "weight", None) is not None:
                self.assertEqual(submodule.weight.dtype, dtype_to_check, msg=f"unexpected weight dtype in {name!r}")
            if getattr(submodule, "bias", None) is not None:
                self.assertEqual(submodule.bias.dtype, dtype_to_check, msg=f"unexpected bias dtype in {name!r}")

    def run_layerwise_upcasting(storage_dtype, compute_dtype):
        torch.manual_seed(0)
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
        model = self.model_class(**config).eval()
        model = model.to(torch_device, dtype=compute_dtype)
        model.enable_layerwise_upcasting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
        check_linear_dtype(model, storage_dtype, compute_dtype)
        # Cast to float32 before converting to numpy so reduced-precision compute dtypes can be compared.
        output = model(**inputs_dict)[0].float().flatten().detach().cpu().numpy()

        # The precision test is not very important for fast tests. In most cases, the outputs will not be the same.
        # We just want to make sure that the layerwise upcasting is working as expected.
        self.assertLess(numpy_cosine_similarity_distance(base_slice, output), 1.0)

    dtype_combinations = (
        (torch.float16, torch.float32),
        (torch.float8_e4m3fn, torch.float32),
        (torch.float8_e5m2, torch.float32),
        (torch.float8_e4m3fn, torch.bfloat16),
    )
    for storage_dtype, compute_dtype in dtype_combinations:
        # subTest labels a failure with the dtype combination that produced it.
        with self.subTest(storage_dtype=storage_dtype, compute_dtype=compute_dtype):
            run_layerwise_upcasting(storage_dtype, compute_dtype)
1374 | 1389 |
|
@require_torch_gpu
def test_layerwise_upcasting_memory(self):
    """Check that layerwise upcasting reduces memory usage.

    Measures the model memory footprint and the peak CUDA memory allocated during one forward pass for a
    plain fp32 model and for fp8_e4m3 storage with fp32 and bf16 compute.
    """

    def reset_memory_stats():
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    def get_memory_usage(storage_dtype, compute_dtype):
        # A storage_dtype of None measures the model without layerwise upcasting (baseline).
        torch.manual_seed(0)
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
        model = self.model_class(**config).eval()
        model = model.to(torch_device, dtype=compute_dtype)
        if storage_dtype is not None:
            model.enable_layerwise_upcasting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)

        # Reset right before the forward pass so the peak statistic reflects inference only.
        reset_memory_stats()
        model(**inputs_dict)
        model_memory_footprint = model.get_memory_footprint()
        peak_inference_memory_allocated = torch.cuda.max_memory_allocated()

        return model_memory_footprint, peak_inference_memory_allocated

    fp32_memory_footprint, _ = get_memory_usage(None, torch.float32)
    fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32)
    fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage(
        torch.float8_e4m3fn, torch.bfloat16
    )

    # Storing weights in fp8 must shrink the model relative to plain fp32 weights.
    self.assertLess(fp8_e4m3_fp32_memory_footprint, fp32_memory_footprint)
    self.assertLess(fp8_e4m3_bf16_memory_footprint, fp8_e4m3_fp32_memory_footprint)
    self.assertLess(fp8_e4m3_bf16_max_memory, fp8_e4m3_fp32_max_memory)
1411 | 1420 |
|
1412 | 1421 |
|
1413 | 1422 | @is_staging_test |
|
0 commit comments