@@ -629,12 +629,8 @@ def _test_quant_type(self, quantization_config, expected_slice):
         output = pipe(**inputs)[0].flatten()
         output_slice = np.concatenate((output[:16], output[-16:]))

-        for weight in [
-            pipe.transformer.x_embedder.weight,
-            pipe.transformer.transformer_blocks[0].ff.net[2].weight,
-            pipe.transformer.transformer_blocks[-1].ff.net[2].weight,
-        ]:
-            self.assertTrue(isinstance(weight, AffineQuantizedTensor))
+        weight = pipe.transformer.x_embedder.weight
+        self.assertTrue(isinstance(weight, AffineQuantizedTensor))
         self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))

     def test_quantization(self):
@@ -643,7 +639,7 @@ def test_quantization(self):
             ("int8wo", np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])),
             ("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])),
         ]
-
+  
         if TorchAoConfig._is_cuda_capability_atleast_8_9():
             QUANTIZATION_TYPES_TO_TEST.extend([
                 ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
@@ -672,10 +668,41 @@ def test_serialization(self):

         with tempfile.TemporaryDirectory() as tmp_dir:
             pipe.save_pretrained(tmp_dir, safe_serialization=False)
-            loaded_pipe = FluxPipeline.from_pretrained(tmp_dir, use_safetensors=False).to(torch_device)
+            del pipe
+            gc.collect()
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+            loaded_pipe = FluxPipeline.from_pretrained(tmp_dir, use_safetensors=False)
+            loaded_pipe.enable_model_cpu_offload()

         weight = loaded_pipe.transformer.x_embedder.weight
         self.assertTrue(isinstance(weight, AffineQuantizedTensor))

         loaded_output = loaded_pipe(**inputs)[0].flatten()
         self.assertTrue(np.allclose(output, loaded_output, atol=1e-3, rtol=1e-3))
+
+    def test_memory_footprint_int4wo(self):
+        # The original checkpoints are in bf16 and about 24 GB
+        expected_memory_in_gb = 6.0
+        quantization_config = TorchAoConfig("int4wo")
+        transformer = FluxTransformer2DModel.from_pretrained(
+            "black-forest-labs/FLUX.1-dev",
+            subfolder="transformer",
+            quantization_config=quantization_config,
+            torch_dtype=torch.bfloat16,
+        )
+        int4wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
+        self.assertTrue(int4wo_memory_in_gb < expected_memory_in_gb)
+
+    def test_memory_footprint_int8wo(self):
+        # The original checkpoints are in bf16 and about 24 GB
+        expected_memory_in_gb = 12.0
+        quantization_config = TorchAoConfig("int8wo")
+        transformer = FluxTransformer2DModel.from_pretrained(
+            "black-forest-labs/FLUX.1-dev",
+            subfolder="transformer",
+            quantization_config=quantization_config,
+            torch_dtype=torch.bfloat16,
+        )
+        int8wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
+        self.assertTrue(int8wo_memory_in_gb < expected_memory_in_gb)