| 14 | 14 | # limitations under the License. | 
| 15 | 15 |  | 
| 16 | 16 | import copy | 
|  | 17 | +import gc | 
| 17 | 18 | import inspect | 
| 18 | 19 | import json | 
| 19 | 20 | import os | 
| 56 | 57 |     CaptureLogger, | 
| 57 | 58 |     get_python_version, | 
| 58 | 59 |     is_torch_compile, | 
|  | 60 | +    numpy_cosine_similarity_distance, | 
| 59 | 61 |     require_torch_2, | 
| 60 | 62 |     require_torch_accelerator_with_training, | 
| 61 | 63 |     require_torch_gpu, | 
| @@ -1331,6 +1333,82 @@ def test_variant_sharded_ckpt_right_format(self): | 
| 1331 | 1333 |                 # Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors | 
| 1332 | 1334 |                 assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files) | 
| 1333 | 1335 |  | 
|  | 1336 | +    def test_layerwise_upcasting_inference(self): | 
|  | 1337 | +        torch.manual_seed(0) | 
|  | 1338 | +        config, inputs_dict = self.prepare_init_args_and_inputs_for_common() | 
|  | 1339 | +        model = self.model_class(**config).eval() | 
|  | 1340 | +        model = model.to(torch_device) | 
|  | 1341 | +        base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy() | 
|  | 1342 | + | 
|  | 1343 | +        # fp16-fp32 | 
|  | 1344 | +        torch.manual_seed(0) | 
|  | 1345 | +        config, inputs_dict = self.prepare_init_args_and_inputs_for_common() | 
|  | 1346 | +        model = self.model_class(**config).eval() | 
|  | 1347 | +        model = model.to(torch_device) | 
|  | 1348 | +        model.enable_layerwise_upcasting(storage_dtype=torch.float16, compute_dtype=torch.float32) | 
|  | 1349 | +        layerwise_upcast_slice_fp16 = model(**inputs_dict)[0].flatten().detach().cpu().numpy() | 
|  | 1350 | + | 
|  | 1351 | +        # The precision test is not very important for fast tests. In most cases, the outputs will not be the same. | 
|  | 1352 | +        # We just want to make sure that the layerwise upcasting is working as expected. | 
|  | 1353 | +        self.assertTrue(numpy_cosine_similarity_distance(base_slice, layerwise_upcast_slice_fp16) < 1.0) | 
|  | 1354 | + | 
|  | 1355 | +        # fp8_e4m3-fp32 | 
|  | 1356 | +        torch.manual_seed(0) | 
|  | 1357 | +        config, inputs_dict = self.prepare_init_args_and_inputs_for_common() | 
|  | 1358 | +        model = self.model_class(**config).eval() | 
|  | 1359 | +        model = model.to(torch_device) | 
|  | 1360 | +        model.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32) | 
|  | 1361 | +        layerwise_upcast_slice_fp8_e4m3 = model(**inputs_dict)[0].flatten().detach().cpu().numpy() | 
|  | 1362 | + | 
|  | 1363 | +        self.assertTrue(numpy_cosine_similarity_distance(base_slice, layerwise_upcast_slice_fp8_e4m3) < 1.0) | 
|  | 1364 | + | 
|  | 1365 | +        # fp8_e5m2-fp32 | 
|  | 1366 | +        torch.manual_seed(0) | 
|  | 1367 | +        config, inputs_dict = self.prepare_init_args_and_inputs_for_common() | 
|  | 1368 | +        model = self.model_class(**config).eval() | 
|  | 1369 | +        model = model.to(torch_device) | 
|  | 1370 | +        model.enable_layerwise_upcasting(storage_dtype=torch.float8_e5m2, compute_dtype=torch.float32) | 
|  | 1371 | +        layerwise_upcast_slice_fp8_e5m2 = model(**inputs_dict)[0].flatten().detach().cpu().numpy() | 
|  | 1372 | + | 
|  | 1373 | +        self.assertTrue(numpy_cosine_similarity_distance(base_slice, layerwise_upcast_slice_fp8_e5m2) < 1.0) | 
|  | 1374 | + | 
|  | 1375 | +    @require_torch_gpu | 
|  | 1376 | +    def test_layerwise_upcasting_memory(self): | 
|  | 1377 | +        # fp32 | 
|  | 1378 | +        gc.collect() | 
|  | 1379 | +        torch.cuda.empty_cache() | 
|  | 1380 | +        torch.cuda.reset_peak_memory_stats() | 
|  | 1381 | +        torch.cuda.synchronize() | 
|  | 1382 | + | 
|  | 1383 | +        torch.manual_seed(0) | 
|  | 1384 | +        config, inputs_dict = self.prepare_init_args_and_inputs_for_common() | 
|  | 1385 | +        model = self.model_class(**config).eval() | 
|  | 1386 | +        model = model.to(torch_device) | 
|  | 1387 | +        model(**inputs_dict) | 
|  | 1388 | +        base_memory_footprint = model.get_memory_footprint() | 
|  | 1389 | +        base_max_memory = torch.cuda.max_memory_allocated() | 
|  | 1390 | + | 
|  | 1391 | +        model.to("cpu") | 
|  | 1392 | +        del model | 
|  | 1393 | + | 
|  | 1394 | +        # fp8_e4m3-fp32 | 
|  | 1395 | +        gc.collect() | 
|  | 1396 | +        torch.cuda.empty_cache() | 
|  | 1397 | +        torch.cuda.reset_peak_memory_stats() | 
|  | 1398 | +        torch.cuda.synchronize() | 
|  | 1399 | + | 
|  | 1400 | +        torch.manual_seed(0) | 
|  | 1401 | +        config, inputs_dict = self.prepare_init_args_and_inputs_for_common() | 
|  | 1402 | +        model = self.model_class(**config).eval() | 
|  | 1403 | +        model = model.to(torch_device) | 
|  | 1404 | +        model.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32) | 
|  | 1405 | +        model(**inputs_dict) | 
|  | 1406 | +        fp8_e4m3_memory_footprint = model.get_memory_footprint() | 
|  | 1407 | +        fp8_e4m3_max_memory = torch.cuda.max_memory_allocated() | 
|  | 1408 | + | 
|  | 1409 | +        self.assertTrue(fp8_e4m3_memory_footprint < base_memory_footprint) | 
|  | 1410 | +        self.assertTrue(fp8_e4m3_max_memory < base_max_memory) | 
|  | 1411 | + | 
| 1334 | 1412 |  | 
| 1335 | 1413 | @is_staging_test | 
| 1336 | 1414 | class ModelPushToHubTester(unittest.TestCase): | 
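For readers skimming the diff: `enable_layerwise_upcasting(storage_dtype=..., compute_dtype=...)` is exercised here as a memory optimization that keeps a module's weights in a low-precision storage dtype at rest and casts them up to the compute dtype only around each layer's forward pass. Below is a rough sketch of that idea using PyTorch forward hooks; it is illustrative only, `apply_layerwise_upcasting_sketch` is a made-up helper and not the API under test:

```python
import torch
import torch.nn as nn


def apply_layerwise_upcasting_sketch(
    module: nn.Module,
    storage_dtype: torch.dtype = torch.float8_e4m3fn,
    compute_dtype: torch.dtype = torch.float32,
) -> None:
    """Illustrative only: keep weights in `storage_dtype` at rest and
    temporarily cast them to `compute_dtype` for each forward pass."""
    for submodule in module.modules():
        # Only touch leaf modules that own parameters (Linear, Conv, ...).
        # Real implementations typically also skip precision-sensitive
        # layers such as normalization.
        if next(submodule.parameters(recurse=False), None) is None:
            continue

        submodule.to(storage_dtype)  # low-precision storage at rest

        def pre_hook(mod, args):
            mod.to(compute_dtype)  # upcast just-in-time for compute
            return args

        def post_hook(mod, args, output):
            mod.to(storage_dtype)  # downcast again to free memory
            return output

        submodule.register_forward_pre_hook(pre_hook)
        submodule.register_forward_hook(post_hook)
```

Because the weights sit in fp8 at rest under this scheme, both the at-rest footprint and the peak allocation drop relative to an fp32 baseline, which is exactly what `test_layerwise_upcasting_memory` asserts via `get_memory_footprint()` and `torch.cuda.max_memory_allocated()`.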
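The inference test compares outputs with a cosine-distance metric rather than exact equality, since low-precision storage perturbs the numerics. Here is a minimal sketch of such a helper, assuming `numpy_cosine_similarity_distance` computes 1 minus the cosine similarity of the flattened outputs (the in-tree implementation may differ):

```python
import numpy as np


def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    # Returns 1 - cosine similarity of the flattened arrays:
    # 0.0 for identically-directed outputs, up to 2.0 for opposite ones,
    # so the `< 1.0` assertions above are deliberately loose sanity checks.
    a, b = a.flatten(), b.flatten()
    similarity = (a * b).sum() / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(1.0 - similarity)
```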