Commit 258e2b7

enable group_offload cases and quanto cases on XPU
Signed-off-by: YAO Matrix <[email protected]>
1 parent b4be422 commit 258e2b7

2 files changed: +18, -10 lines changed


tests/pipelines/test_pipelines_common.py

Lines changed: 2 additions & 1 deletion
@@ -53,6 +53,7 @@
     require_accelerator,
     require_hf_hub_version_greater,
     require_torch,
+    require_torch_accelerator,
     require_torch_gpu,
     require_transformers_version_greater,
     skip_mps,
@@ -2210,7 +2211,7 @@ def test_layerwise_casting_inference(self):
         inputs = self.get_dummy_inputs(torch_device)
         _ = pipe(**inputs)[0]
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_group_offloading_inference(self):
        if not self.test_group_offloading:
            return
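
Switching the skip decorator from require_torch_gpu to require_torch_accelerator is what lets the group-offloading test run on XPU as well as CUDA machines. As a rough illustration (not the library's actual code, which lives in diffusers.utils.testing_utils and may check more backends), a device-agnostic decorator of this kind can be sketched as:

# Minimal sketch of a device-agnostic skip decorator in the spirit of
# require_torch_accelerator; illustrative only.
import unittest

import torch


def accelerator_is_available() -> bool:
    # torch.xpu is the Intel GPU backend this commit targets; torch.cuda covers
    # NVIDIA (and ROCm builds, which also expose torch.cuda).
    if torch.cuda.is_available():
        return True
    return hasattr(torch, "xpu") and torch.xpu.is_available()


def require_torch_accelerator_sketch(test_case):
    """Skip the decorated test when no hardware accelerator is present."""
    return unittest.skipUnless(accelerator_is_available(), "test requires an accelerator")(test_case)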

tests/quantization/quanto/test_quanto.py

Lines changed: 16 additions & 9 deletions
@@ -6,9 +6,11 @@
 from diffusers.models.attention_processor import Attention
 from diffusers.utils import is_optimum_quanto_available, is_torch_available
 from diffusers.utils.testing_utils import (
+    enable_full_determinism,
     nightly,
     numpy_cosine_similarity_distance,
     require_accelerate,
+    require_big_accelerator,
     require_big_gpu_with_torch_cuda,
     require_torch_cuda_compatibility,
     torch_device,
@@ -23,9 +25,11 @@
 
 from ..utils import LoRALayer, get_memory_consumption_stat
 
+enable_full_determinism()
+
 
 @nightly
-@require_big_gpu_with_torch_cuda
+@require_big_accelerator
 @require_accelerate
 class QuantoBaseTesterMixin:
     model_id = None
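
Calling enable_full_determinism() at module import time pins the quanto tests to deterministic kernels, which matters once the same nightly suite runs on more than one backend. What such a switch typically configures is sketched below, under the assumption of the standard PyTorch determinism knobs; the function diffusers actually exports may set more or fewer of them:

# Sketch of the knobs a full-determinism helper usually sets; illustrative only,
# not the body of diffusers' enable_full_determinism.
import os
import random

import numpy as np
import torch


def enable_full_determinism_sketch(seed: int = 0) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Ask PyTorch for deterministic kernel implementations where they exist.
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Needed by some deterministic cuBLAS GEMM paths on CUDA.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"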
@@ -37,15 +41,17 @@ class QuantoBaseTesterMixin:
     keep_in_fp32_module = ""
     modules_to_not_convert = ""
     _test_torch_compile = False
+    torch_accelerator_module = None
 
     def setUp(self):
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.empty_cache()
+        self.torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
+        self.torch_accelerator_module.reset_peak_memory_stats()
+        self.torch_accelerator_module.empty_cache()
         gc.collect()
 
     def tearDown(self):
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.empty_cache()
+        self.torch_accelerator_module.reset_peak_memory_stats()
+        self.torch_accelerator_module.empty_cache()
         gc.collect()
 
     def get_dummy_init_kwargs(self):
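
The setUp/tearDown change replaces hard-coded torch.cuda calls with whichever backend module matches torch_device: getattr(torch, torch_device, torch.cuda) returns torch.xpu on an XPU box and falls back to torch.cuda otherwise, and both modules expose the same reset_peak_memory_stats()/empty_cache() surface. A standalone sketch of the pattern follows; torch_device is resolved by hand here purely for illustration, whereas the tests import it from diffusers.utils.testing_utils:

# Sketch of the backend-module lookup used in setUp().
import gc

import torch

torch_device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"

# torch.cuda and torch.xpu share the memory-management API the tests rely on.
backend = getattr(torch, torch_device, torch.cuda)
if backend.is_available():
    backend.reset_peak_memory_stats()
    backend.empty_cache()
gc.collect()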
@@ -89,7 +95,7 @@ def test_keep_modules_in_fp32(self):
         self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module
 
         model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
-        model.to("cuda")
+        model.to(torch_device)
 
         for name, module in model.named_modules():
             if isinstance(module, torch.nn.Linear):
@@ -107,7 +113,7 @@ def test_modules_to_not_convert(self):
         init_kwargs.update({"quantization_config": quantization_config})
 
         model = self.model_cls.from_pretrained(**init_kwargs)
-        model.to("cuda")
+        model.to(torch_device)
 
         for name, module in model.named_modules():
             if name in self.modules_to_not_convert:
@@ -122,7 +128,8 @@ def test_dtype_assignment(self):
 
         with self.assertRaises(ValueError):
             # Tries with a `device` and `dtype`
-            model.to(device="cuda:0", dtype=torch.float16)
+            device_0 = f"{torch_device}:0"
+            model.to(device=device_0, dtype=torch.float16)
 
         with self.assertRaises(ValueError):
             # Tries with a cast
@@ -133,7 +140,7 @@ def test_dtype_assignment(self):
             model.half()
 
         # This should work
-        model.to("cuda")
+        model.to(torch_device)
 
     def test_serialization(self):
         model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
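
The remaining edits swap literal "cuda" strings for torch_device, so the same assertions exercise model.to() on whichever accelerator the runner exposes, and the indexed device string is built with an f-string instead of hard-coding "cuda:0". For example (torch_device again resolved by hand to keep the snippet self-contained):

import torch

# Hypothetical local stand-in for the torch_device the tests import.
torch_device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"

device_0 = f"{torch_device}:0"  # "cuda:0" on NVIDIA runners, "xpu:0" on Intel GPUs

model = torch.nn.Linear(4, 4)
if getattr(torch, torch_device).is_available():
    model.to(torch_device)       # generic placement, as in test_keep_modules_in_fp32
    model.to(device=device_0)    # indexed device string built the same way as in test_dtype_assignment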
