     FlowMatchEulerDiscreteScheduler,
     FluxPipeline,
     FluxTransformer2DModel,
-    TorchAoConfig,
 )
+
 from diffusers.models.attention_processor import Attention
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     numpy_cosine_similarity_distance,
     require_torch,
     require_torch_gpu,
-    require_torchao_version_greater_or_equal,
     slow,
     torch_device,
 )
     import torch
     import torch.nn as nn

-    from ..utils import LoRALayer, get_memory_consumption_stat
+    from tests.quantization.utils import LoRALayer, get_memory_consumption_stat


 @require_torch
@@ -95,14 +94,23 @@ def get_dummy_components(
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
+            device_map=torch_device,
+        )
+        text_encoder = CLIPTextModel.from_pretrained(
+            model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
         )
-        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
         text_encoder_2 = T5EncoderModel.from_pretrained(
             model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
         )
-        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
-        tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
-        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
+        tokenizer = CLIPTokenizer.from_pretrained(
+            model_id, subfolder="tokenizer"
+        )
+        tokenizer_2 = AutoTokenizer.from_pretrained(
+            model_id, subfolder="tokenizer_2"
+        )
+        vae = AutoencoderKL.from_pretrained(
+            model_id, subfolder="vae", torch_dtype=torch.bfloat16
+        )
         scheduler = FlowMatchEulerDiscreteScheduler()

         return {
@@ -195,13 +203,18 @@ def _test_quantization_output(self, quantization_config: FinegrainedFP8Config, e
         self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))

     def test_quantization(self):
-        expected_slice = [np.array([0.34179688, -0.03613281, 0.01428223, -0.22949219, -0.49609375, 0.4375, -0.1640625, -0.66015625, 0.43164062]), np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])]
+        expected_slice = [
+            np.array([0.46679688, 0.51953125, 0.5546875, 0.421875, 0.44140625, 0.64453125, 0.43359375, 0.453125, 0.5625]),
+            np.array([0.46679688, 0.51953125, 0.5546875, 0.421875, 0.44140625, 0.64453125, 0.43359375, 0.453125, 0.5625])
+        ]
+
         for index, model_id in enumerate(["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]):
             quantization_config = FinegrainedFP8Config(
                 modules_to_not_convert=["x_embedder", "proj_out"],
                 weight_block_size=(32, 32)
             )
-            self._test_quantization_output(quantization_config, model_id, expected_slice[index])
+
+            self._test_quantization_output(quantization_config, expected_slice[index], model_id)

     def test_dtype(self):
         """
@@ -216,13 +229,18 @@ def test_dtype(self):
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
+            device_map=torch_device,
         )

+        layer = quantized_model.transformer_blocks[0].ff.net[2]
         weight = quantized_model.transformer_blocks[0].ff.net[2].weight
         weight_scale_inv = quantized_model.transformer_blocks[0].ff.net[2].weight_scale_inv
-        self.assertTrue(isinstance(weight, FP8Linear))
+
+        self.assertTrue(isinstance(layer, FP8Linear))
+
         self.assertEqual(weight_scale_inv.dtype, torch.bfloat16)
-        self.assertEqual(weight.weight.dtype, torch.float8_e4m3fn)
+
+        self.assertEqual(weight.dtype, torch.float8_e4m3fn)

     def test_device_map_auto(self):
         """
@@ -232,16 +250,16 @@ def test_device_map_auto(self):
         inputs = self.get_dummy_tensor_inputs(torch_device)
         # requires with different expected slices since models are different due to offload (we don't quantize modules offloaded to cpu/disk)
         expected_slice_auto = np.array(
-            [
-                0.34179688,
-                -0.03613281,
-                0.01428223,
-                -0.22949219,
-                -0.49609375,
+            [
+                0.34375,
+                -0.0402832,
+                0.01226807,
+                -0.22851562,
+                -0.49414062,
                 0.4375,
-                -0.1640625,
+                -0.16992188,
                 -0.66015625,
-                0.43164062,
+                0.43164062
             ]
         )

@@ -259,6 +277,7 @@ def test_device_map_auto(self):

         output = quantized_model(**inputs)[0]
         output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+
         self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice_auto) < 1e-3)


@@ -269,6 +288,7 @@ def test_modules_to_not_convert(self):
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
+            device_map=torch_device,
         )

         unquantized_layer = quantized_model_with_not_convert.transformer_blocks[0].ff.net[2]
@@ -284,6 +304,7 @@ def test_modules_to_not_convert(self):
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
+            device_map=torch_device,
         )

         size_quantized_with_not_convert = self.get_model_size_in_bytes(quantized_model_with_not_convert)
@@ -298,7 +319,8 @@ def test_training(self):
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
-        ).to(torch_device)
+            device_map=torch_device,
+        )

         for param in quantized_model.parameters():
             # freeze the model as only adapter layers will be trained
@@ -321,10 +343,10 @@ def test_training(self):
             if isinstance(module, LoRALayer):
                 self.assertTrue(module.adapter[1].weight.grad is not None)
                 self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
-
+
     @nightly
     def test_torch_compile(self):
-        r"""Test that verifies if torch.compile works with torchao quantization."""
+        r"""Test that verifies if torch.compile works with fp8 quantization."""
         for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
             quantization_config = FinegrainedFP8Config(weight_block_size=(32, 32), modules_to_not_convert=["x_embedder", "proj_out"])
             components = self.get_dummy_components(quantization_config, model_id=model_id)
@@ -350,9 +372,14 @@ def test_memory_footprint(self):
         transformer_quantized = self.get_dummy_components(FinegrainedFP8Config(weight_block_size=(32, 32), modules_to_not_convert=["x_embedder", "proj_out"]), model_id=model_id)["transformer"]
         transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]

-        for name, module in transformer_quantized.named_modules():
-            if isinstance(module, nn.Linear) and name not in ["x_embedder", "proj_out"]:
-                self.assertTrue(isinstance(module.weight, FP8Linear))
+        for (name, module_quantized), (name_bf16, module_bf16) in zip(transformer_quantized.named_modules(), transformer_bf16.named_modules()):
+            if isinstance(module_bf16, nn.Linear) and name_bf16.split(".")[-1] not in ["x_embedder", "proj_out"]:
+
+                self.assertTrue(isinstance(module_quantized, FP8Linear))
+
+                self.assertEqual(module_quantized.weight.shape, module_bf16.weight.shape)
+
+                self.assertEqual(name, name_bf16)


         total_quantized = self.get_model_size_in_bytes(transformer_quantized)
@@ -373,7 +400,9 @@ def test_model_memory_usage(self):

         transformer_quantized = self.get_dummy_components(FinegrainedFP8Config(weight_block_size=(32, 32), modules_to_not_convert=["x_embedder", "proj_out"]), model_id=model_id)["transformer"]
         transformer_quantized.to(torch_device)
+
         quantized_model_memory = get_memory_consumption_stat(transformer_quantized, inputs)
+
         self.assertTrue(unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio)

     def test_exception_of_cpu_in_device_map(self):
@@ -413,6 +442,7 @@ def get_dummy_model(self, device=None):
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
+            device_map=torch_device,
         )
         return quantized_model.to(device)

@@ -448,6 +478,7 @@ def _test_original_model_expected_slice(self, expected_slice):
         inputs = self.get_dummy_tensor_inputs(torch_device)
         output = quantized_model(**inputs)[0]
         output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+
         self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

     def _check_serialization_expected_slice(self, expected_slice, device):
@@ -463,18 +494,34 @@ def _check_serialization_expected_slice(self, expected_slice, device):
         output = loaded_quantized_model(**inputs)[0]

         output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+
         self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

     def test_slice_output(self):
-        expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
-        device = "cuda"
+        expected_slice = np.array(
+            [
+                0.34960938,
+                -0.12109375,
+                -0.02648926,
+                -0.25195312,
+                -0.45898438,
+                0.49609375,
+                -0.14453125,
+                -0.69921875,
+                0.44921875
+            ]
+        )
+
+        device = torch_device
+
         self._test_original_model_expected_slice(expected_slice)
+
         self._check_serialization_expected_slice(expected_slice, device)


 @require_torch
 @require_torch_gpu
 @slow
-@nightly
+
 class SlowTorchAoTests(unittest.TestCase):
     def tearDown(self):
         gc.collect()
@@ -490,16 +537,23 @@ def get_dummy_components(self, quantization_config: FinegrainedFP8Config):
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
             cache_dir=cache_dir,
+            device_map=torch_device,
         )
         text_encoder = CLIPTextModel.from_pretrained(
             model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, cache_dir=cache_dir
         )
         text_encoder_2 = T5EncoderModel.from_pretrained(
             model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16, cache_dir=cache_dir
         )
-        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir)
-        tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir)
-        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16, cache_dir=cache_dir)
+        tokenizer = CLIPTokenizer.from_pretrained(
+            model_id, subfolder="tokenizer", cache_dir=cache_dir
+        )
+        tokenizer_2 = AutoTokenizer.from_pretrained(
+            model_id, subfolder="tokenizer_2", cache_dir=cache_dir
+        )
+        vae = AutoencoderKL.from_pretrained(
+            model_id, subfolder="vae", torch_dtype=torch.bfloat16, cache_dir=cache_dir
+        )
         scheduler = FlowMatchEulerDiscreteScheduler()

         return {
@@ -536,12 +590,11 @@ def _test_quant_output(self, quantization_config, expected_slice):
         inputs = self.get_dummy_inputs(torch_device)
         output = pipe(**inputs)[0].flatten()
         output_slice = np.concatenate((output[:16], output[-16:]))
+
         self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))

     def test_quantization(self):
-        # fmt: off
-        expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
-        # fmt: on
+        expected_slice = np.array([0., -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])

         quantization_config = FinegrainedFP8Config(weight_block_size=(32, 32), modules_to_not_convert=["x_embedder", "proj_out"])
         self._test_quant_output(quantization_config, expected_slice)
@@ -562,7 +615,7 @@ def test_serialization(self):
         torch.cuda.empty_cache()
         torch.cuda.synchronize()
         transformer = FluxTransformer2DModel.from_pretrained(
-            tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
+            tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False, device_map=torch_device
         )
         pipe.transformer = transformer

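For reviewers, the loading pattern these tests exercise can be summarized in a short sketch. This is not code from the diff: it assumes FinegrainedFP8Config is importable from the top-level diffusers namespace (as the test module's imports imply) and reuses the tiny checkpoint and arguments that appear above.

# Minimal sketch of the FP8 loading pattern under test (assumed import path).
import torch
from diffusers import FinegrainedFP8Config, FluxTransformer2DModel

quantization_config = FinegrainedFP8Config(
    weight_block_size=(32, 32),                           # quantize weights in 32x32 blocks
    modules_to_not_convert=["x_embedder", "proj_out"],    # leave these layers in bf16
)
transformer = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe",                 # tiny checkpoint used by the tests
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    device_map="cuda",                                    # the tests pass device_map instead of calling .to()
)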