@@ -131,7 +131,9 @@ def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()
 
-    def get_dummy_components(self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"):
+    def get_dummy_components(
+        self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"
+    ):
         transformer = FluxTransformer2DModel.from_pretrained(
             model_id,
             subfolder="transformer",
@@ -436,7 +438,9 @@ def test_memory_footprint(self):
         """
         for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
             transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"), model_id=model_id)["transformer"]
-            transformer_int4wo_gs32 = self.get_dummy_components(TorchAoConfig("int4wo", group_size=32), model_id=model_id)["transformer"]
+            transformer_int4wo_gs32 = self.get_dummy_components(
+                TorchAoConfig("int4wo", group_size=32), model_id=model_id
+            )["transformer"]
             transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
             transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
 
@@ -654,7 +658,7 @@ def test_quantization(self):
             gc.collect()
             torch.cuda.empty_cache()
             torch.cuda.synchronize()
-     
+
     def test_serialization(self):
         quantization_config = TorchAoConfig("int8wo")
         components = self.get_dummy_components(quantization_config)
@@ -673,6 +677,6 @@ def test_serialization(self):
 
         weight = loaded_pipe.transformer.x_embedder.weight
         self.assertTrue(isinstance(weight, AffineQuantizedTensor))
-         
+
         loaded_output = loaded_pipe(**inputs)[0].flatten()
         self.assertTrue(np.allclose(output, loaded_output, atol=1e-3, rtol=1e-3))
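
For context, a minimal sketch of the pattern these tests exercise: loading the tiny Flux transformer with a TorchAoConfig and checking that quantized weights are torchao AffineQuantizedTensor instances. The from_pretrained keyword arguments here are assumptions based on the diffusers quantization API, not copied from the test file; only the identifiers that appear in the diff are taken from the tests themselves.

```python
# Sketch only: assumes a diffusers build with TorchAO support and torchao installed.
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig
from torchao.dtypes import AffineQuantizedTensor

# "int8wo" = int8 weight-only quantization, the scheme used in test_serialization.
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

# Quantized linear weights are wrapped in torchao's AffineQuantizedTensor,
# which is what test_serialization asserts after a save/load round trip.
assert isinstance(transformer.x_embedder.weight, AffineQuantizedTensor)
```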