 )
 from diffusers.models.attention_processor import Attention
 from diffusers.utils.testing_utils import (
+    enable_full_determinism,
     is_torch_available,
     is_torchao_available,
     require_torch,
     require_torch_gpu,
     require_torchao_version_greater,
+    slow,
     torch_device,
 )
 
 
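+# Full determinism seeds the RNGs and forces deterministic kernel selection; the assumption here
+# is that this is what makes the hard-coded output slices in these tests reproducible across runs.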
+enable_full_determinism()
+
+
 if is_torch_available():
     import torch
     import torch.nn as nn
@@ -101,9 +106,21 @@ def test_repr(self):
         Check that there is no error in the repr
         """
         quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8)
-        repr(quantization_config)
-
-
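+        # The config is expected to serialize as the JSON document below; whitespace and newlines
+        # are stripped from both sides so the comparison ignores indentation differences.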
+        expected_repr = """TorchAoConfig {
+            "modules_to_not_convert": [
+                "conv"
+            ],
+            "quant_method": "torchao",
+            "quant_type": "int4_weight_only",
+            "quant_type_kwargs": {
+                "group_size": 8
+            }
+        }""".replace(" ", "").replace("\n", "")
+        quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
+        self.assertEqual(quantization_repr, expected_repr)
+
+
+# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
 @require_torch_gpu
 @require_torchao_version_greater("0.6.0")
@@ -202,32 +219,44 @@ def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: L
         self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
     def test_quantization(self):
-        # TODO(aryan): update these values from our CI
+        # fmt: off
         QUANTIZATION_TYPES_TO_TEST = [
-            ("int4wo", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-            ("int4dq", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-            ("int8wo", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-            ("int8dq", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-            ("uint4wo", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-            ("int_a8w8", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-            ("uint_a16w7", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
+            ("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
+            ("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
+            ("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
+            ("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
+            ("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
+            ("int_a8w8", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
+            ("uint_a16w7", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
         ]
 
         if TorchAoConfig._is_cuda_capability_atleast_8_9():
-            QUANTIZATION_TYPES_TO_TEST.extend(
-                [
-                    ("float8wo_e5m2", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-                    ("float8wo_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-                    ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-                    ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-                    ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-                    ("fp4wo", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-                    ("fp6", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
-                ]
-            )
+            QUANTIZATION_TYPES_TO_TEST.extend([
+                ("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
+                ("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
+                # =====
+                # The following lead to an internal torch error:
+                #    RuntimeError: mat2 shape (32x4 must be divisible by 16
+                # Skip these for now; TODO(aryan): investigate later
+                # ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
+                # ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
+                # =====
+                # Cutlass fails to initialize for below
+                # ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
+                # =====
+                ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
+                ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
+            ])
+        # fmt: on
 
         for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
-            quantization_config = TorchAoConfig(quant_type=quantization_name)
+            quant_kwargs = {}
+            if quantization_name in ["uint4wo", "uint_a16w7"]:
+                # The dummy flux model that we use requires us to impose some restrictions on group_size here
+                quant_kwargs.update({"group_size": 16})
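+            # x_embedder is excluded for every variant; the assumption is that its projection shapes
+            # do not satisfy the divisibility constraints of some of the quantized kernels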
+            quantization_config = TorchAoConfig(
+                quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
+            )
             self._test_quant_type(quantization_config, expected_slice)
 
     def test_int4wo_quant_bfloat16_conversion(self):
@@ -277,10 +306,9 @@ def test_offload(self):
             )
 
             output = quantized_model(**inputs)[0]
-
             output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-            # TODO(aryan): get slice from CI
-            expected_slice = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])
+
+            expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
             self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
     def test_modules_to_not_convert(self):
@@ -333,6 +361,7 @@ def test_training(self):
                 self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
 
     def test_torch_compile(self):
+        r"""Test that torch.compile works with torchao quantization."""
         quantization_config = TorchAoConfig("int8_weight_only")
         components = self.get_dummy_components(quantization_config)
         pipe = FluxPipeline(**components)
@@ -348,7 +377,54 @@ def test_torch_compile(self):
         # Note: Seems to require higher tolerance
         self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))
 
+    @staticmethod
+    def _get_memory_footprint(module):
+        quantized_param_memory = 0.0
+        unquantized_param_memory = 0.0
+
+        for param in module.parameters():
+            if param.__class__.__name__ == "AffineQuantizedTensor":
+                data, scale, zero_point = param.layout_tensor.get_plain()
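+                # get_plain() unpacks the tensor into its raw quantized data plus the affine
+                # quantization parameters (scale and zero point); each contributes
+                # numel * element_size bytes to the footprint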
+                quantized_param_memory += data.numel() * data.element_size()
+                quantized_param_memory += scale.numel() * scale.element_size()
+                quantized_param_memory += zero_point.numel() * zero_point.element_size()
+            else:
+                unquantized_param_memory += param.data.numel() * param.data.element_size()
+
+        total_memory = quantized_param_memory + unquantized_param_memory
+        return total_memory, quantized_param_memory, unquantized_param_memory
+
+    def test_memory_footprint(self):
+        r"""
+        A simple test that checks whether the model conversion was done correctly by comparing the
+        memory footprints of the converted models against the unquantized bf16 model
+        """
+        transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"))["transformer"]
+        transformer_int4wo_gs32 = self.get_dummy_components(TorchAoConfig("int4wo", group_size=32))["transformer"]
+        transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"]
+        transformer_bf16 = self.get_dummy_components(None)["transformer"]
+
+        total_int4wo, quantized_int4wo, unquantized_int4wo = self._get_memory_footprint(transformer_int4wo)
+        total_int4wo_gs32, quantized_int4wo_gs32, unquantized_int4wo_gs32 = self._get_memory_footprint(
+            transformer_int4wo_gs32
+        )
+        total_int8wo, quantized_int8wo, unquantized_int8wo = self._get_memory_footprint(transformer_int8wo)
+        total_bf16, quantized_bf16, unquantized_bf16 = self._get_memory_footprint(transformer_bf16)
 
+        self.assertTrue(quantized_bf16 == 0 and total_bf16 == unquantized_bf16)
+        # int4wo_gs32 has a smaller group size, so more groups -> more scales and zero points
+        self.assertTrue(total_int8wo < total_bf16 < total_int4wo_gs32)
+        # int4 with the default group size quantizes very few linear layers compared to a group size of 32
+        self.assertTrue(quantized_int4wo < quantized_int4wo_gs32 and unquantized_int4wo > unquantized_int4wo_gs32)
+        # int8 quantizes more layers compared to int4 with the default group size
+        self.assertTrue(quantized_int8wo < quantized_int4wo)
+
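+    # An unrecognized quant_type string should surface as a ValueError at config/conversion time
+    # rather than silently skipping quantization.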
+    def test_wrong_config(self):
+        with self.assertRaises(ValueError):
+            self.get_dummy_components(TorchAoConfig("int42"))
+
+
+# This class is not to be run as a test by itself. See the tests that follow this class.
 @require_torch
 @require_torch_gpu
 @require_torchao_version_greater("0.6.0")
@@ -371,14 +447,15 @@ def get_dummy_model(self, device=None):
         )
         return quantized_model.to(device)
 
-    def get_dummy_tensor_inputs(self, device=None):
+    def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
         batch_size = 1
         num_latent_channels = 4
         num_image_channels = 3
         height = width = 4
         sequence_length = 48
         embedding_dim = 32
 
+        torch.manual_seed(seed)
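+        # Seeding makes the random inputs deterministic so the model outputs can be compared
+        # against the hard-coded expected slices in the serialization tests below.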
         hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)
         encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
             device, dtype=torch.bfloat16
@@ -425,27 +502,112 @@ def test_serialization_expected_slice(self):
 
 class TorchAoSerializationINTA8W8Test(TorchAoSerializationTest):
     quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
-    expected_slice = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])
+    expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
     serialized_expected_slice = expected_slice
     device = "cuda"
 
 
 class TorchAoSerializationINTA16W8Test(TorchAoSerializationTest):
     quant_method, quant_method_kwargs = "int8_weight_only", {}
-    expected_slice = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])
+    expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
     serialized_expected_slice = expected_slice
     device = "cuda"
 
 
 class TorchAoSerializationINTA8W8CPUTest(TorchAoSerializationTest):
     quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
-    expected_slice = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])
+    expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
     serialized_expected_slice = expected_slice
     device = "cpu"
 
 
 class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest):
     quant_method, quant_method_kwargs = "int8_weight_only", {}
-    expected_slice = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])
+    expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
     serialized_expected_slice = expected_slice
     device = "cpu"
+
+
+# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
+@require_torch
+@require_torch_gpu
+@require_torchao_version_greater("0.6.0")
+@slow
+class SlowTorchAoTests(unittest.TestCase):
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_dummy_components(self, quantization_config: TorchAoConfig):
+        model_id = "black-forest-labs/FLUX.1-dev"
+        transformer = FluxTransformer2DModel.from_pretrained(
+            model_id,
+            subfolder="transformer",
+            quantization_config=quantization_config,
+            torch_dtype=torch.bfloat16,
+        )
+        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
+        text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
+        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
+        tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
+        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
+        scheduler = FlowMatchEulerDiscreteScheduler()
+
+        return {
+            "scheduler": scheduler,
+            "text_encoder": text_encoder,
+            "text_encoder_2": text_encoder_2,
+            "tokenizer": tokenizer,
+            "tokenizer_2": tokenizer_2,
+            "transformer": transformer,
+            "vae": vae,
+        }
+
+    def get_dummy_inputs(self, device: torch.device, seed: int = 0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator().manual_seed(seed)
+
+        inputs = {
+            "prompt": "an astronaut riding a horse in space",
+            "height": 512,
+            "width": 512,
+            "num_inference_steps": 20,
+            "output_type": "np",
+            "generator": generator,
+        }
+
+        return inputs
+
+    def _test_quant_type(self, quantization_config, expected_slice):
+        components = self.get_dummy_components(quantization_config)
+        pipe = FluxPipeline(**components).to(dtype=torch.bfloat16)
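+        # Offloading moves each component to the GPU only while it runs, keeping peak VRAM
+        # manageable for the full FLUX.1-dev pipeline on a single test runner.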
+        pipe.enable_model_cpu_offload()
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output = pipe(**inputs)[0].flatten()
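+        # The reference slices store the first and last 16 values of the flattened output image.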
+        output_slice = np.concatenate((output[:16], output[-16:]))
+
+        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+
+    def test_quantization(self):
+        # fmt: off
+        QUANTIZATION_TYPES_TO_TEST = [
+            ("int8wo", np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])),
+            ("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])),
+        ]
+
+        if TorchAoConfig._is_cuda_capability_atleast_8_9():
+            QUANTIZATION_TYPES_TO_TEST.extend([
+                ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
+                ("fp5_e3m1", np.array([0.0527, 0.0742, 0.1289, 0.0449, 0.0625, 0.1308, 0.0585, 0.0742, 0.1269, 0.0585, 0.0722, 0.1328, 0.0566, 0.0742, 0.1347, 0.0585, 0.3691, 0.7578, 0.5429, 0.4355, 0.7695, 0.5546, 0.4414, 0.7578, 0.5468, 0.4179, 0.7265, 0.5273, 0.3945, 0.6992, 0.5234, 0.4316])),
+            ])
+        # fmt: on
+
+        for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
+            quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"])
+            self._test_quant_type(quantization_config, expected_slice)
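+            # Release the current pipeline's GPU memory before loading the next quantized variant.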
+            gc.collect()
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()