@@ -1,3 +1,4 @@
+import tempfile
 import unittest
 
 import torch
@@ -9,7 +10,6 @@
 from diffusers.utils import is_optimum_quanto_available
 from diffusers.utils.testing_utils import (
     nightly,
-    numpy_cosine_similarity_distance,
     require_accelerate,
     require_big_gpu_with_torch_cuda,
     torch_device,
@@ -29,6 +29,7 @@ class QuantoBaseTesterMixin:
     torch_dtype = torch.bfloat16
     expected_memory_use_in_gb = 5
     keep_in_fp32_module = ""
+    modules_to_not_convert = ""
 
     def get_dummy_init_kwargs(self):
         return {"weights": "float8"}
@@ -76,6 +77,22 @@ def test_keep_modules_in_fp32(self):
                     assert module.weight.dtype == torch.float32
         self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules
 
+    def test_modules_to_not_convert(self):
+        init_kwargs = self.get_dummy_model_init_kwargs()
+
+        quantization_config_kwargs = self.get_dummy_init_kwargs()
+        quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert})
+        quantization_config = QuantoConfig(**quantization_config_kwargs)
+
+        init_kwargs.update({"quantization_config": quantization_config})
+
+        model = self.model_cls.from_pretrained(**init_kwargs)
+        model.to("cuda")
+
+        for name, module in model.named_modules():
+            if name in self.modules_to_not_convert:
+                assert not isinstance(module, QLinear)
+
     def test_dtype_assignment(self):
         model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
         assert (model.get_memory_footprint() / 1024**3) < self.expected_memory_use_in_gb
@@ -99,12 +116,35 @@ def test_dtype_assignment(self):
         # This should work
         model.to("cuda")
 
+    def test_serialization(self):
+        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+        inputs = self.get_dummy_inputs()
+
+        model.to(torch_device)
+        with torch.no_grad():
+            model_output = model(**inputs)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir)
+            saved_model = self.model_cls.from_pretrained(
+                tmp_dir,
+                torch_dtype=torch.bfloat16,
+            )
+
+        saved_model.to(torch_device)
+        with torch.no_grad():
+            saved_model_output = saved_model(**inputs)
+
+        max_diff = torch.abs(model_output.sample - saved_model_output.sample).max()
+        assert max_diff < 1e-5
+
 
 class FluxTransformerQuantoMixin(QuantoBaseTesterMixin):
     model_id = "hf-internal-testing/tiny-flux-transformer"
     model_cls = FluxTransformer2DModel
     torch_dtype = torch.bfloat16
     keep_in_fp32_module = "proj_out"
+    modules_to_not_convert = ["proj_out"]
 
     def get_dummy_inputs(self):
         return {
@@ -130,14 +170,21 @@ def get_dummy_inputs(self):
         }
 
 
-class FluxTransformerFloat8(FluxTransformerQuantoMixin, unittest.TestCase):
+class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
     expected_memory_use_in_gb = 10
 
     def get_dummy_init_kwargs(self):
         return {"weights": "float8"}
 
 
-class FluxTransformerInt8(FluxTransformerQuantoMixin, unittest.TestCase):
+class FluxTransformerFloat8WeightsAndActivationTest(FluxTransformerQuantoMixin, unittest.TestCase):
+    expected_memory_use_in_gb = 10
+
+    def get_dummy_init_kwargs(self):
+        return {"weights": "float8", "activations": "float8"}
+
+
+class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
     expected_memory_use_in_gb = 10
 
     def get_dummy_init_kwargs(self):
@@ -157,20 +204,42 @@ def test_torch_compile(self):
         with torch.no_grad():
             compiled_model_output = compiled_model(**inputs).sample
 
-        max_diff = numpy_cosine_similarity_distance(
-            model_output.cpu().flatten(), compiled_model_output.cpu().flatten()
-        )
+        max_diff = torch.abs(model_output - compiled_model_output).max()
+        assert max_diff < 1e-4
+
+
+class FluxTransformerInt8WeightsAndActivationTest(FluxTransformerQuantoMixin, unittest.TestCase):
+    expected_memory_use_in_gb = 10
+
+    def get_dummy_init_kwargs(self):
+        return {"weights": "int8", "activations": "int8"}
+
+    def test_torch_compile(self):
+        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+        compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True)
+        inputs = self.get_dummy_inputs()
+
+        model.to(torch_device)
+        with torch.no_grad():
+            model_output = model(**inputs).sample
+        model.to("cpu")
+
+        compiled_model.to(torch_device)
+        with torch.no_grad():
+            compiled_model_output = compiled_model(**inputs).sample
+
+        max_diff = torch.abs(model_output - compiled_model_output).max()
         assert max_diff < 1e-4
 
 
-class FluxTransformerInt4(FluxTransformerQuantoMixin, unittest.TestCase):
+class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
     expected_memory_use_in_gb = 6
 
     def get_dummy_init_kwargs(self):
         return {"weights": "int4"}
 
 
-class FluxTransformerInt2(FluxTransformerQuantoMixin, unittest.TestCase):
+class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
     expected_memory_use_in_gb = 6
 
     def get_dummy_init_kwargs(self):
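For context, a minimal usage sketch (not part of this diff) of the configuration the new tests exercise. It mirrors `get_dummy_init_kwargs` and `test_modules_to_not_convert` above; the keyword names (`weights`, `modules_to_not_convert`) and the assumption that `QuantoConfig` is importable from `diffusers` follow this version of the test code and may differ in other diffusers releases.

```python
# Sketch only: quantize the tiny Flux transformer used by these tests to float8
# weights while leaving "proj_out" unquantized, mirroring the new
# modules_to_not_convert test above. Keyword names follow this PR's test code.
import torch

from diffusers import FluxTransformer2DModel, QuantoConfig

quantization_config = QuantoConfig(weights="float8", modules_to_not_convert=["proj_out"])
model = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-transformer",  # model_id from FluxTransformerQuantoMixin
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
model.to("cuda")
```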