10 changes: 5 additions & 5 deletions tests/quantization/test_torch_compile_utils.py

@@ -18,10 +18,10 @@
 import torch

 from diffusers import DiffusionPipeline
-from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import backend_empty_cache, require_torch_accelerator, slow, torch_device


-@require_torch_gpu
+@require_torch_accelerator
 @slow
 class QuantCompileTests:
     @property
@@ -51,7 +51,7 @@ def _init_pipeline(self, quantization_config, torch_dtype):
         return pipe

     def _test_torch_compile(self, torch_dtype=torch.bfloat16):
-        pipe = self._init_pipeline(self.quantization_config, torch_dtype).to("cuda")
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype).to(torch_device)
         # `fullgraph=True` ensures no graph breaks
         pipe.transformer.compile(fullgraph=True)
@@ -71,7 +71,7 @@ def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16

         pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         group_offload_kwargs = {
-            "onload_device": torch.device("cuda"),
+            "onload_device": torch.device(torch_device),
             "offload_device": torch.device("cpu"),
             "offload_type": "leaf_level",
             "use_stream": use_stream,
@@ -81,7 +81,7 @@ def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16
         for name, component in pipe.components.items():
             if name != "transformer" and isinstance(component, torch.nn.Module):
                 if torch.device(component.device).type == "cpu":
-                    component.to("cuda")
+                    component.to(torch_device)

         # small resolutions to ensure speedy execution.
         pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
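Note: this file's changes replace hard-coded "cuda" targets with the torch_device helper and swap @require_torch_gpu for @require_torch_accelerator, so the tests can also run on non-CUDA accelerators such as XPU. As a minimal illustrative sketch (not the actual diffusers implementation, which lives in diffusers.utils.testing_utils), such a helper can be resolved roughly like this:

    # Illustrative sketch only; the real helper in
    # diffusers.utils.testing_utils may differ in detail.
    import torch

    if torch.cuda.is_available():
        torch_device = "cuda"
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch_device = "xpu"
    else:
        torch_device = "cpu"

    # Tests then write pipe.to(torch_device) instead of pipe.to("cuda").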
4 changes: 2 additions & 2 deletions tests/quantization/torchao/test_torchao.py

@@ -236,7 +236,7 @@ def test_quantization(self):
             ("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
         ]

-        if TorchAoConfig._is_cuda_capability_atleast_8_9():
+        if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
            QUANTIZATION_TYPES_TO_TEST.extend([
                ("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
                ("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
@@ -753,7 +753,7 @@ def test_quantization(self):
             ("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])),
         ]

-        if TorchAoConfig._is_cuda_capability_atleast_8_9():
+        if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
            QUANTIZATION_TYPES_TO_TEST.extend([
                ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
                ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
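Note: _is_xpu_or_cuda_capability_atleast_8_9 widens the old CUDA-only gate so the float8 test cases also run on XPU. A hedged sketch of what such a gate might check (an assumption for illustration; the actual private TorchAoConfig method may differ):

    import torch

    def is_xpu_or_cuda_capability_atleast_8_9() -> bool:
        # Assumption: XPU devices reaching this path support the fp8 kernels.
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return True
        # fp8 (e4m3 / e5m2) CUDA kernels need compute capability >= 8.9 (Ada).
        if torch.cuda.is_available():
            return torch.cuda.get_device_capability() >= (8, 9)
        return False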