@@ -6,10 +6,13 @@
 from diffusers.models.attention_processor import Attention
 from diffusers.utils import is_optimum_quanto_available, is_torch_available
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    backend_reset_peak_memory_stats,
+    enable_full_determinism,
     nightly,
     numpy_cosine_similarity_distance,
     require_accelerate,
-    require_big_gpu_with_torch_cuda,
+    require_big_accelerator,
     require_torch_cuda_compatibility,
     torch_device,
 )
@@ -23,9 +26,11 @@

 from ..utils import LoRALayer, get_memory_consumption_stat

+enable_full_determinism()
+

 @nightly
-@require_big_gpu_with_torch_cuda
+@require_big_accelerator
 @require_accelerate
 class QuantoBaseTesterMixin:
     model_id = None
@@ -39,13 +44,13 @@ class QuantoBaseTesterMixin:
     _test_torch_compile = False

     def setUp(self):
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.empty_cache()
+        backend_reset_peak_memory_stats(torch_device)
+        backend_empty_cache(torch_device)
         gc.collect()

     def tearDown(self):
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.empty_cache()
+        backend_reset_peak_memory_stats(torch_device)
+        backend_empty_cache(torch_device)
         gc.collect()

     def get_dummy_init_kwargs(self):
@@ -89,7 +94,7 @@ def test_keep_modules_in_fp32(self):
         self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module

         model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
-        model.to("cuda")
+        model.to(torch_device)

         for name, module in model.named_modules():
             if isinstance(module, torch.nn.Linear):
@@ -107,7 +112,7 @@ def test_modules_to_not_convert(self):
         init_kwargs.update({"quantization_config": quantization_config})

         model = self.model_cls.from_pretrained(**init_kwargs)
-        model.to("cuda")
+        model.to(torch_device)

         for name, module in model.named_modules():
             if name in self.modules_to_not_convert:
@@ -122,7 +127,8 @@ def test_dtype_assignment(self):

         with self.assertRaises(ValueError):
             # Tries with a `device` and `dtype`
-            model.to(device="cuda:0", dtype=torch.float16)
+            device_0 = f"{torch_device}:0"
+            model.to(device=device_0, dtype=torch.float16)

         with self.assertRaises(ValueError):
             # Tries with a cast
@@ -133,7 +139,7 @@ def test_dtype_assignment(self):
             model.half()

         # This should work
-        model.to("cuda")
+        model.to(torch_device)

     def test_serialization(self):
         model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
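The hunks above replace `torch.cuda`-specific calls with the device-agnostic helpers from `diffusers.utils.testing_utils`, which take `torch_device` and dispatch to whichever backend the tests run on. Below is a minimal sketch of that dispatch pattern, assuming a plain per-backend branch; the function names are hypothetical and not the library's actual implementation.

import torch


def sketch_backend_empty_cache(device) -> None:
    # Resolve "cuda", "cuda:0", "mps", "cpu", ... to the bare backend type.
    device_type = torch.device(device).type
    if device_type == "cuda":
        torch.cuda.empty_cache()
    elif device_type == "mps":
        torch.mps.empty_cache()
    # CPU (and backends not listed in this sketch) have no allocator cache to release.


def sketch_backend_reset_peak_memory_stats(device) -> None:
    # Peak-memory tracking is only reset for CUDA in this sketch.
    if torch.device(device).type == "cuda":
        torch.cuda.reset_peak_memory_stats()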