Commit c062b08

Commit message: fix conflicts.
Merge commit with 2 parents: 5ad508f + a7e9f85

4 files changed: +40 −20 lines


src/diffusers/utils/testing_utils.py

Lines changed: 11 additions & 0 deletions
@@ -1186,6 +1186,13 @@ def _is_torch_fp64_available(device):
     "mps": 0,
     "default": 0,
 }
+BACKEND_SYNCHRONIZE = {
+    "cuda": torch.cuda.synchronize,
+    "xpu": getattr(torch.xpu, "synchronize", None),
+    "cpu": None,
+    "mps": None,
+    "default": None,
+}
 
 
 # This dispatches a defined function according to the accelerator from the function definitions.
@@ -1208,6 +1215,10 @@ def backend_manual_seed(device: str, seed: int):
     return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
 
 
+def backend_synchronize(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
+
+
 def backend_empty_cache(device: str):
     return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
 
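The new BACKEND_SYNCHRONIZE table plugs into the same device-agnostic dispatch used by the other backend_* helpers: the test passes a device string, the table supplies the matching accelerator call, and a None entry means "nothing to do" for that backend. Below is a minimal sketch of that pattern; the names mirror the diff, but the dispatcher is a simplified stand-in for diffusers' internal _device_agnostic_dispatch, not the library code.

from typing import Any, Callable, Dict, Optional

import torch

# Sketch of the dispatch-table pattern used by the testing utilities.
BACKEND_SYNCHRONIZE: Dict[str, Optional[Callable]] = {
    "cuda": torch.cuda.synchronize,
    "xpu": getattr(torch.xpu, "synchronize", None) if hasattr(torch, "xpu") else None,
    "cpu": None,
    "mps": None,
    "default": None,
}


def _dispatch(device: str, table: Dict[str, Optional[Callable]], *args: Any) -> Any:
    backend = device.split(":")[0]      # "cuda:0" -> "cuda"
    fn = table.get(backend, table["default"])
    if fn is None:
        return None                     # no-op for backends with nothing to synchronize
    return fn(*args)


def backend_synchronize(device: str) -> None:
    _dispatch(device, BACKEND_SYNCHRONIZE)

With this in place, a test can call backend_synchronize(torch_device) unconditionally: on CUDA or XPU it waits for queued kernels, while on CPU and MPS it degrades to a no-op, which is what the memory tests below rely on.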

tests/models/test_modeling_common.py

Lines changed: 11 additions & 8 deletions
@@ -59,6 +59,9 @@
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     backend_empty_cache,
+    backend_max_memory_allocated,
+    backend_reset_peak_memory_stats,
+    backend_synchronize,
     get_python_version,
     is_torch_compile,
     numpy_cosine_similarity_distance,
@@ -340,7 +343,7 @@ def test_weight_overwrite(self):
 
         assert model.config.in_channels == 9
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_keep_modules_in_fp32(self):
         r"""
         A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32 when we load the model in fp16/bf16
@@ -1479,16 +1482,16 @@ def test_layerwise_casting(storage_dtype, compute_dtype):
         test_layerwise_casting(torch.float8_e5m2, torch.float32)
         test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_layerwise_casting_memory(self):
         MB_TOLERANCE = 0.2
         LEAST_COMPUTE_CAPABILITY = 8.0
 
         def reset_memory_stats():
             gc.collect()
-            torch.cuda.synchronize()
-            torch.cuda.empty_cache()
-            torch.cuda.reset_peak_memory_stats()
+            backend_synchronize(torch_device)
+            backend_empty_cache(torch_device)
+            backend_reset_peak_memory_stats(torch_device)
 
         def get_memory_usage(storage_dtype, compute_dtype):
             torch.manual_seed(0)
@@ -1501,7 +1504,7 @@ def get_memory_usage(storage_dtype, compute_dtype):
             reset_memory_stats()
             model(**inputs_dict)
             model_memory_footprint = model.get_memory_footprint()
-            peak_inference_memory_allocated_mb = torch.cuda.max_memory_allocated() / 1024**2
+            peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2
 
             return model_memory_footprint, peak_inference_memory_allocated_mb
 
@@ -1511,7 +1514,7 @@ def get_memory_usage(storage_dtype, compute_dtype):
             torch.float8_e4m3fn, torch.bfloat16
         )
 
-        compute_capability = get_torch_cuda_device_capability()
+        compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
         self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint)
         # NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
         # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
@@ -1526,7 +1529,7 @@ def get_memory_usage(storage_dtype, compute_dtype):
         )
 
     @parameterized.expand([False, True])
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_group_offloading(self, record_stream):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        torch.manual_seed(0)
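The memory test now follows one device-agnostic recipe: collect garbage, synchronize, empty the allocator cache, reset peak statistics, run inference, then read back the peak allocation through the dispatcher. A standalone sketch of that flow is below; measure_peak_inference_memory is a hypothetical helper name, the imports are the testing_utils helpers the test switches to, and, as in the test, an accelerator backend is assumed to be active.

import gc

import torch

from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_max_memory_allocated,
    backend_reset_peak_memory_stats,
    backend_synchronize,
    torch_device,
)


def measure_peak_inference_memory(model, inputs):
    """Hypothetical helper: peak inference memory in MB on the active accelerator."""
    gc.collect()
    backend_synchronize(torch_device)             # wait for pending kernels
    backend_empty_cache(torch_device)             # release cached allocator blocks
    backend_reset_peak_memory_stats(torch_device)

    with torch.no_grad():
        model(**inputs)

    # Peak bytes recorded since the reset, converted to mebibytes.
    return backend_max_memory_allocated(torch_device) / 1024**2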

tests/pipelines/test_pipelines_common.py

Lines changed: 2 additions & 2 deletions
@@ -53,7 +53,7 @@
     require_accelerator,
     require_hf_hub_version_greater,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_transformers_version_greater,
     skip_mps,
     torch_device,
@@ -2212,7 +2212,7 @@ def test_layerwise_casting_inference(self):
         inputs = self.get_dummy_inputs(torch_device)
         _ = pipe(**inputs)[0]
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_group_offloading_inference(self):
         if not self.test_group_offloading:
             return
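The only change here is widening the gating decorator from CUDA-only to any supported accelerator. Conceptually, such a decorator skips the test unless a non-CPU torch backend is usable; the snippet below is a simplified illustrative stand-in (require_accelerator_sketch and _accelerator_available are hypothetical names, not the diffusers API).

import unittest

import torch


def _accelerator_available() -> bool:
    # Hypothetical check: is any non-CPU torch backend usable on this machine?
    if torch.cuda.is_available():
        return True
    if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        return True
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return True
    return False


def require_accelerator_sketch(test_case):
    # Simplified stand-in for a decorator like require_torch_accelerator.
    return unittest.skipUnless(_accelerator_available(), "test requires a torch accelerator")(test_case)


class ExampleTest(unittest.TestCase):
    @require_accelerator_sketch
    def test_only_runs_with_an_accelerator(self):
        self.assertTrue(_accelerator_available())

Swapping @require_torch_gpu for the accelerator-level decorator is what lets the group-offloading inference test run on non-CUDA runners such as XPU instead of being skipped outright.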

tests/quantization/quanto/test_quanto.py

Lines changed: 16 additions & 10 deletions
@@ -6,10 +6,13 @@
 from diffusers.models.attention_processor import Attention
 from diffusers.utils import is_optimum_quanto_available, is_torch_available
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    backend_reset_peak_memory_stats,
+    enable_full_determinism,
     nightly,
     numpy_cosine_similarity_distance,
     require_accelerate,
-    require_big_gpu_with_torch_cuda,
+    require_big_accelerator,
     require_torch_cuda_compatibility,
     torch_device,
 )
@@ -23,9 +26,11 @@
 
 from ..utils import LoRALayer, get_memory_consumption_stat
 
+enable_full_determinism()
+
 
 @nightly
-@require_big_gpu_with_torch_cuda
+@require_big_accelerator
 @require_accelerate
 class QuantoBaseTesterMixin:
     model_id = None
@@ -39,13 +44,13 @@ class QuantoBaseTesterMixin:
     _test_torch_compile = False
 
     def setUp(self):
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.empty_cache()
+        backend_reset_peak_memory_stats(torch_device)
+        backend_empty_cache(torch_device)
         gc.collect()
 
     def tearDown(self):
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.empty_cache()
+        backend_reset_peak_memory_stats(torch_device)
+        backend_empty_cache(torch_device)
         gc.collect()
 
     def get_dummy_init_kwargs(self):
@@ -89,7 +94,7 @@ def test_keep_modules_in_fp32(self):
         self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module
 
         model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
-        model.to("cuda")
+        model.to(torch_device)
 
         for name, module in model.named_modules():
             if isinstance(module, torch.nn.Linear):
@@ -107,7 +112,7 @@ def test_modules_to_not_convert(self):
         init_kwargs.update({"quantization_config": quantization_config})
 
         model = self.model_cls.from_pretrained(**init_kwargs)
-        model.to("cuda")
+        model.to(torch_device)
 
         for name, module in model.named_modules():
             if name in self.modules_to_not_convert:
@@ -122,7 +127,8 @@ def test_dtype_assignment(self):
 
         with self.assertRaises(ValueError):
             # Tries with a `device` and `dtype`
-            model.to(device="cuda:0", dtype=torch.float16)
+            device_0 = f"{torch_device}:0"
+            model.to(device=device_0, dtype=torch.float16)
 
         with self.assertRaises(ValueError):
             # Tries with a cast
@@ -133,7 +139,7 @@ def test_dtype_assignment(self):
             model.half()
 
         # This should work
-        model.to("cuda")
+        model.to(torch_device)
 
     def test_serialization(self):
         model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
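The quanto changes repeat the same recipe at class scope: setUp/tearDown reset accelerator memory state through the backend_* dispatchers, models move to torch_device instead of a hard-coded "cuda", and an ordinal is appended only where the test needs an explicit device index. A condensed sketch of that pattern follows; the test class and its body are placeholders for illustration, while the imports match the ones used in the diff.

import gc
import unittest

import torch

from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_reset_peak_memory_stats,
    torch_device,
)


class DeviceAgnosticQuantoStyleTest(unittest.TestCase):
    """Placeholder class illustrating the setUp/tearDown pattern from the diff."""

    def setUp(self):
        backend_reset_peak_memory_stats(torch_device)
        backend_empty_cache(torch_device)
        gc.collect()

    def tearDown(self):
        # Same cleanup after each test so peak-memory stats never leak between cases.
        backend_reset_peak_memory_stats(torch_device)
        backend_empty_cache(torch_device)
        gc.collect()

    def test_move_to_indexed_device(self):
        model = torch.nn.Linear(4, 4)
        # Append an ordinal only for backends that accept one; plain "cpu" stays as-is.
        device_0 = f"{torch_device}:0" if torch_device != "cpu" else torch_device
        model.to(device_0)
        self.assertEqual(next(model.parameters()).device.type, torch.device(device_0).type)

Together with the enable_full_determinism() call added at import time, this is intended to keep the nightly quanto suite reproducible while letting it run on CUDA, XPU, and other accelerators.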
