 from diffusers.models.attention_processor import Attention
 from diffusers.utils import is_optimum_quanto_available, is_torch_available
 from diffusers.utils.testing_utils import (
+    enable_full_determinism,
     nightly,
     numpy_cosine_similarity_distance,
     require_accelerate,
+    require_big_accelerator,
     require_big_gpu_with_torch_cuda,
     require_torch_cuda_compatibility,
     torch_device,

 from ..utils import LoRALayer, get_memory_consumption_stat

+enable_full_determinism()
+

 @nightly
-@require_big_gpu_with_torch_cuda
+@require_big_accelerator
 @require_accelerate
 class QuantoBaseTesterMixin:
     model_id = None
@@ -37,15 +41,17 @@ class QuantoBaseTesterMixin:
     keep_in_fp32_module = ""
     modules_to_not_convert = ""
     _test_torch_compile = False
+    torch_accelerator_module = None

     def setUp(self):
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.empty_cache()
+        self.torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
+        self.torch_accelerator_module.reset_peak_memory_stats()
+        self.torch_accelerator_module.empty_cache()
         gc.collect()

     def tearDown(self):
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.empty_cache()
+        self.torch_accelerator_module.reset_peak_memory_stats()
+        self.torch_accelerator_module.empty_cache()
         gc.collect()

     def get_dummy_init_kwargs(self):
@@ -89,7 +95,7 @@ def test_keep_modules_in_fp32(self):
         self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module

         model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
-        model.to("cuda")
+        model.to(torch_device)

         for name, module in model.named_modules():
             if isinstance(module, torch.nn.Linear):
@@ -107,7 +113,7 @@ def test_modules_to_not_convert(self):
         init_kwargs.update({"quantization_config": quantization_config})

         model = self.model_cls.from_pretrained(**init_kwargs)
-        model.to("cuda")
+        model.to(torch_device)

         for name, module in model.named_modules():
             if name in self.modules_to_not_convert:
@@ -122,7 +128,8 @@ def test_dtype_assignment(self):

         with self.assertRaises(ValueError):
             # Tries with a `device` and `dtype`
-            model.to(device="cuda:0", dtype=torch.float16)
+            device_0 = f"{torch_device}:0"
+            model.to(device=device_0, dtype=torch.float16)

         with self.assertRaises(ValueError):
             # Tries with a cast
@@ -133,7 +140,7 @@ def test_dtype_assignment(self):
             model.half()

         # This should work
-        model.to("cuda")
+        model.to(torch_device)

     def test_serialization(self):
         model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
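
For reference, a minimal standalone sketch (not part of the diff above) of the accelerator dispatch pattern the patch introduces: getattr(torch, torch_device, torch.cuda) resolves the backend namespace named by the device string (torch.cuda for "cuda", torch.xpu for "xpu", and so on) and falls back to torch.cuda when torch exposes no attribute with that name. The helper name below is illustrative only and does not come from the diffusers test suite.

    import gc

    import torch


    def reset_accelerator_memory(torch_device: str = "cuda") -> None:
        # Illustrative helper (not from diffusers): resolve the backend-specific
        # namespace from the device string, falling back to torch.cuda when
        # torch has no module with that name.
        torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
        torch_accelerator_module.reset_peak_memory_stats()
        torch_accelerator_module.empty_cache()
        gc.collect()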