
Commit 2e11423

enable torchao cases on XPU

Signed-off-by: Matrix YAO <[email protected]>
1 parent: be2fb77

File tree: 3 files changed, +27 -25 lines

src/diffusers/quantizers/quantization_config.py (7 additions, 7 deletions)
@@ -656,13 +656,13 @@ def generate_fpx_quantization_types(bits: int):
 
     @staticmethod
     def _is_cuda_capability_atleast_8_9() -> bool:
-        if not torch.cuda.is_available():
-            raise RuntimeError("TorchAO requires a CUDA compatible GPU and installation of PyTorch.")
-
-        major, minor = torch.cuda.get_device_capability()
-        if major == 8:
-            return minor >= 9
-        return major >= 9
+        if torch.cuda.is_available():
+            major, minor = torch.cuda.get_device_capability()
+            if major == 8:
+                return minor >= 9
+            return major >= 9
+        else:
+            return True
 
     def get_apply_tensor_subclass(self):
         TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
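
The reworked capability check only queries CUDA when a CUDA device is actually present; on other backends (e.g. XPU) it now returns True instead of raising, so the fp8 quantization types it gates stay selectable there. A minimal standalone sketch of the new behaviour, with the reasoning spelled out in comments (an illustration only, not a drop-in for the class method):

import torch


def _is_cuda_capability_atleast_8_9() -> bool:
    if torch.cuda.is_available():
        # On CUDA, fp8 kernels need compute capability >= 8.9 (Ada / Hopper).
        major, minor = torch.cuda.get_device_capability()
        if major == 8:
            return minor >= 9
        return major >= 9
    # Non-CUDA accelerators (e.g. XPU) no longer hit a RuntimeError here;
    # returning True keeps the fp8 quant types listed on those backends.
    return True


print(_is_cuda_capability_atleast_8_9())  # True on XPU or CPU-only machines, capability-dependent on CUDA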

tests/models/unets/test_models_unet_2d_condition.py (2 additions, 3 deletions)
@@ -46,7 +46,6 @@
     require_peft_backend,
     require_torch_accelerator,
     require_torch_accelerator_with_fp16,
-    require_torch_gpu,
     skip_mps,
     slow,
     torch_all_close,
@@ -980,13 +979,13 @@ def test_ip_adapter_plus(self):
         assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
         assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
 
-    @require_torch_gpu
     @parameterized.expand(
         [
             ("hf-internal-testing/unet2d-sharded-dummy", None),
             ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
         ]
     )
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
@@ -996,13 +995,13 @@ def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
-    @require_torch_gpu
     @parameterized.expand(
         [
             ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
             ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
         ]
     )
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
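
Both sharded-checkpoint tests now carry the backend-agnostic @require_torch_accelerator marker instead of @require_torch_gpu, so they also run on XPU. As a rough illustration of what such a marker does, here is a hedged sketch of a skip decorator keyed off a resolved torch_device string; the real require_torch_accelerator in diffusers.utils.testing_utils may be implemented differently:

import unittest

import torch

# Assumption for this sketch: torch_device names the best available backend.
if torch.cuda.is_available():
    torch_device = "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    torch_device = "xpu"
else:
    torch_device = "cpu"


def require_torch_accelerator(test_case):
    # Skip unless some non-CPU accelerator (CUDA, XPU, ...) is available.
    return unittest.skipUnless(torch_device != "cpu", "test requires an accelerator")(test_case)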

tests/quantization/torchao/test_torchao.py (18 additions, 15 deletions)
@@ -30,12 +30,15 @@
 )
 from diffusers.models.attention_processor import Attention
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    backend_synchronize,
     enable_full_determinism,
     is_torch_available,
     is_torchao_available,
     nightly,
     numpy_cosine_similarity_distance,
     require_torch,
+    require_torch_accelerator,
     require_torch_gpu,
     require_torchao_version_greater_or_equal,
     slow,
@@ -61,7 +64,7 @@
 
 
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoConfigTest(unittest.TestCase):
     def test_to_dict(self):
@@ -119,12 +122,12 @@ def test_repr(self):
 
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoTest(unittest.TestCase):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def get_dummy_components(
         self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"
@@ -518,14 +521,14 @@ def test_sequential_cpu_offload(self):
 
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoSerializationTest(unittest.TestCase):
     model_name = "hf-internal-testing/tiny-flux-pipe"
 
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
         quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
@@ -596,14 +599,14 @@ def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs,
     def test_int_a8w8_cuda(self):
         quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
         expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
-        device = "cuda"
+        device = torch_device
         self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
         self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
 
     def test_int_a16w8_cuda(self):
         quant_method, quant_method_kwargs = "int8_weight_only", {}
         expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
-        device = "cuda"
+        device = torch_device
         self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
         self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
 
@@ -624,14 +627,14 @@ def test_int_a16w8_cpu(self):
 
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 @slow
 @nightly
 class SlowTorchAoTests(unittest.TestCase):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def get_dummy_components(self, quantization_config: TorchAoConfig):
         # This is just for convenience, so that we can modify it at one place for custom environments and locally testing
@@ -713,8 +716,8 @@ def test_quantization(self):
             quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"])
             self._test_quant_type(quantization_config, expected_slice)
             gc.collect()
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
+            backend_empty_cache(torch_device)
+            backend_synchronize(torch_device)
 
     def test_serialization_int8wo(self):
         quantization_config = TorchAoConfig("int8wo")
@@ -733,8 +736,8 @@ def test_serialization_int8wo(self):
             pipe.remove_all_hooks()
             del pipe.transformer
             gc.collect()
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
+            backend_empty_cache(torch_device)
+            backend_synchronize(torch_device)
             transformer = FluxTransformer2DModel.from_pretrained(
                 tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
             )
@@ -783,14 +786,14 @@ def test_memory_footprint_int8wo(self):
 
 
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 @slow
 @nightly
 class SlowTorchAoPreserializedModelTests(unittest.TestCase):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def get_dummy_inputs(self, device: torch.device, seed: int = 0):
         if str(device).startswith("mps"):
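
Throughout the torchao tests, torch.cuda.empty_cache() / torch.cuda.synchronize() are replaced by backend_empty_cache / backend_synchronize called with torch_device, so teardown and cache clearing work on whichever accelerator the suite runs on. A hedged sketch of that dispatch pattern follows; the helpers shipped in diffusers.utils.testing_utils may differ in detail:

import torch


def backend_empty_cache(device: str) -> None:
    # Release cached allocator memory on the active backend.
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "xpu":
        torch.xpu.empty_cache()
    elif device == "mps":
        torch.mps.empty_cache()


def backend_synchronize(device: str) -> None:
    # Block until all queued kernels on the active backend have finished.
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "xpu":
        torch.xpu.synchronize()
    elif device == "mps":
        torch.mps.synchronize()

In the tests these are invoked as backend_empty_cache(torch_device) and backend_synchronize(torch_device), as the diff above shows.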
