@@ -75,6 +75,7 @@ def forward(self, input, *args, **kwargs):
 if is_torchao_available():
     from torchao.dtypes import AffineQuantizedTensor
     from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
+    from torchao.utils import get_model_size_in_bytes
 
 
 @require_torch
@@ -138,11 +139,13 @@ def get_dummy_components(self, quantization_config: TorchAoConfig):
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
         )
-        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
-        text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
+        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
+        text_encoder_2 = T5EncoderModel.from_pretrained(
+            model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
+        )
         tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
         tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
-        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
+        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
         scheduler = FlowMatchEulerDiscreteScheduler()
 
         return {
@@ -211,7 +214,7 @@ def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
     def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float]):
         components = self.get_dummy_components(quantization_config)
         pipe = FluxPipeline(**components)
-        pipe.to(device=torch_device, dtype=torch.bfloat16)
+        pipe.to(device=torch_device)
 
         inputs = self.get_dummy_inputs(torch_device)
         output = pipe(**inputs)[0]
@@ -315,21 +318,33 @@ def test_offload(self):
 
     def test_modules_to_not_convert(self):
         quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
-        quantized_model = FluxTransformer2DModel.from_pretrained(
+        quantized_model_with_not_convert = FluxTransformer2DModel.from_pretrained(
             "hf-internal-testing/tiny-flux-pipe",
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
         )
 
-        unquantized_layer = quantized_model.transformer_blocks[0].ff.net[2]
+        unquantized_layer = quantized_model_with_not_convert.transformer_blocks[0].ff.net[2]
         self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear))
         self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor))
         self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16)
 
-        quantized_layer = quantized_model.proj_out
+        quantized_layer = quantized_model_with_not_convert.proj_out
         self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor))
-        self.assertEqual(quantized_layer.weight.layout_tensor.data.dtype, torch.int8)
+
+        quantization_config = TorchAoConfig("int8_weight_only")
+        quantized_model = FluxTransformer2DModel.from_pretrained(
+            "hf-internal-testing/tiny-flux-pipe",
+            subfolder="transformer",
+            quantization_config=quantization_config,
+            torch_dtype=torch.bfloat16,
+        )
+
+        size_quantized_with_not_convert = get_model_size_in_bytes(quantized_model_with_not_convert)
+        size_quantized = get_model_size_in_bytes(quantized_model)
+
+        self.assertTrue(size_quantized < size_quantized_with_not_convert)
 
     def test_training(self):
         quantization_config = TorchAoConfig("int8_weight_only")
@@ -380,23 +395,6 @@ def test_torch_compile(self):
         # Note: Seems to require higher tolerance
         self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))
 
-    @staticmethod
-    def _get_memory_footprint(module):
-        quantized_param_memory = 0.0
-        unquantized_param_memory = 0.0
-
-        for param in module.parameters():
-            if param.__class__.__name__ == "AffineQuantizedTensor":
-                data, scale, zero_point = param.layout_tensor.get_plain()
-                quantized_param_memory += data.numel() + data.element_size()
-                quantized_param_memory += scale.numel() + scale.element_size()
-                quantized_param_memory += zero_point.numel() + zero_point.element_size()
-            else:
-                unquantized_param_memory += param.data.numel() * param.data.element_size()
-
-        total_memory = quantized_param_memory + unquantized_param_memory
-        return total_memory, quantized_param_memory, unquantized_param_memory
-
     def test_memory_footprint(self):
         r"""
         A simple test to check if the model conversion has been done correctly by checking on the
@@ -407,20 +405,18 @@ def test_memory_footprint(self):
         transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"]
         transformer_bf16 = self.get_dummy_components(None)["transformer"]
 
-        total_int4wo, quantized_int4wo, unquantized_int4wo = self._get_memory_footprint(transformer_int4wo)
-        total_int4wo_gs32, quantized_int4wo_gs32, unquantized_int4wo_gs32 = self._get_memory_footprint(
-            transformer_int4wo_gs32
-        )
-        total_int8wo, quantized_int8wo, unquantized_int8wo = self._get_memory_footprint(transformer_int8wo)
-        total_bf16, quantized_bf16, unquantized_bf16 = self._get_memory_footprint(transformer_bf16)
-
-        self.assertTrue(quantized_bf16 == 0 and total_bf16 == unquantized_bf16)
-        # int4wo_gs32 has smaller group size, so more groups -> more scales and zero points
-        self.assertTrue(total_int8wo < total_bf16 < total_int4wo_gs32)
-        # int4 with default group size quantized very few linear layers compared to a smaller group size of 32
-        self.assertTrue(quantized_int4wo < quantized_int4wo_gs32 and unquantized_int4wo > unquantized_int4wo_gs32)
+        total_int4wo = get_model_size_in_bytes(transformer_int4wo)
+        total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32)
+        total_int8wo = get_model_size_in_bytes(transformer_int8wo)
+        total_bf16 = get_model_size_in_bytes(transformer_bf16)
+
+        # Latter has smaller group size, so more groups -> more scales and zero points
+        self.assertTrue(total_int4wo < total_int4wo_gs32)
         # int8 quantizes more layers compare to int4 with default group size
-        self.assertTrue(quantized_int8wo < quantized_int4wo)
+        self.assertTrue(total_int8wo < total_int4wo)
+        # int4wo does not quantize too many layers because of default group size, but for the layers it does
+        # there is additional overhead of scales and zero points
+        self.assertTrue(total_bf16 < total_int4wo)
 
     def test_wrong_config(self):
         with self.assertRaises(ValueError):
@@ -555,11 +551,13 @@ def get_dummy_components(self, quantization_config: TorchAoConfig):
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
         )
-        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
-        text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
+        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
+        text_encoder_2 = T5EncoderModel.from_pretrained(
+            model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
+        )
         tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
         tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
-        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
+        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
         scheduler = FlowMatchEulerDiscreteScheduler()
 
         return {
@@ -591,7 +589,7 @@ def get_dummy_inputs(self, device: torch.device, seed: int = 0):
 
     def _test_quant_type(self, quantization_config, expected_slice):
         components = self.get_dummy_components(quantization_config)
-        pipe = FluxPipeline(**components).to(dtype=torch.bfloat16)
+        pipe = FluxPipeline(**components)
         pipe.enable_model_cpu_offload()
 
         inputs = self.get_dummy_inputs(torch_device)
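
Taken together, these hunks drop the hand-rolled _get_memory_footprint helper in favor of torchao's get_model_size_in_bytes, so the memory assertions now compare whole-model byte sizes instead of separately tracked quantized and unquantized parameter bytes. Below is a minimal sketch of that comparison outside the test harness, assuming a diffusers build with TorchAoConfig support and torchao installed; the checkpoint is the same tiny test repo the tests above use.

    # Sketch only: compare a bf16 and an int8 weight-only quantized copy of the
    # tiny Flux transformer by total size in bytes, as the updated tests do.
    import torch
    from diffusers import FluxTransformer2DModel, TorchAoConfig
    from torchao.utils import get_model_size_in_bytes

    repo = "hf-internal-testing/tiny-flux-pipe"
    bf16_model = FluxTransformer2DModel.from_pretrained(
        repo, subfolder="transformer", torch_dtype=torch.bfloat16
    )
    int8wo_model = FluxTransformer2DModel.from_pretrained(
        repo,
        subfolder="transformer",
        quantization_config=TorchAoConfig("int8_weight_only"),
        torch_dtype=torch.bfloat16,
    )

    # int8 weight-only storage is roughly half of bf16, so the quantized model
    # should report the smaller size.
    assert get_model_size_in_bytes(int8wo_model) < get_model_size_in_bytes(bf16_model)

This is the same ordering the updated test_memory_footprint asserts, and test_modules_to_not_convert uses the size difference to confirm that excluded modules really stay unquantized.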