@@ -577,20 +577,25 @@ def tearDown(self):
         torch.cuda.empty_cache()
 
     def get_dummy_components(self, quantization_config: TorchAoConfig):
+        # This is just for convenience, so that we can modify it in one place for custom environments and local testing
+        cache_dir = None
         model_id = "black-forest-labs/FLUX.1-dev"
         transformer = FluxTransformer2DModel.from_pretrained(
             model_id,
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
+            cache_dir=cache_dir,
+        )
+        text_encoder = CLIPTextModel.from_pretrained(
+            model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, cache_dir=cache_dir
         )
-        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
         text_encoder_2 = T5EncoderModel.from_pretrained(
-            model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
+            model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16, cache_dir=cache_dir
         )
-        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
-        tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
-        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
+        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir)
+        tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir)
+        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16, cache_dir=cache_dir)
         scheduler = FlowMatchEulerDiscreteScheduler()
 
         return {
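
A note on the cache_dir hook introduced above: it is the single place to point these tests at a local model cache. A minimal sketch of wiring it up, assuming an environment-variable override (the variable name is hypothetical, not part of this change):

    import os

    # Hypothetical override: set DIFFUSERS_TEST_CACHE_DIR locally to reuse a
    # shared download cache; None falls back to the default HF cache location.
    cache_dir = os.environ.get("DIFFUSERS_TEST_CACHE_DIR")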
@@ -624,9 +629,9 @@ def _test_quant_type(self, quantization_config, expected_slice):
         components = self.get_dummy_components(quantization_config)
         pipe = FluxPipeline(**components)
         pipe.enable_model_cpu_offload()
-        
+
         weight = pipe.transformer.transformer_blocks[0].ff.net[2].weight
-        self.assertTrue(isinstance(weight, AffineQuantizedTensor))
+        self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
 
         inputs = self.get_dummy_inputs(torch_device)
         output = pipe(**inputs)[0].flatten()
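
Background for the widened assertion: weight-only torchao schemes leave an AffineQuantizedTensor on the module, while dynamic-activation schemes wrap the weight in a LinearActivationQuantizedTensor, so the check must accept either subclass. A minimal sketch against a bare linear layer, assuming a recent torchao release that exports these quantize_ configs:

    import torch
    from torchao.quantization import (
        int8_dynamic_activation_int8_weight,
        int8_weight_only,
        quantize_,
    )

    # Weight-only quantization: the weight becomes an AffineQuantizedTensor.
    wo = torch.nn.Linear(64, 64, dtype=torch.bfloat16)
    quantize_(wo, int8_weight_only())

    # Dynamic activation + weight quantization: the weight is wrapped in a
    # LinearActivationQuantizedTensor instead.
    dq = torch.nn.Linear(64, 64, dtype=torch.bfloat16)
    quantize_(dq, int8_dynamic_activation_int8_weight())

    print(type(wo.weight), type(dq.weight))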
@@ -643,7 +648,7 @@ def test_quantization(self):
         if TorchAoConfig._is_cuda_capability_atleast_8_9():
             QUANTIZATION_TYPES_TO_TEST.extend([
                 ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
-                ("fp5_e3m1", np.array([0.0527, 0.0742, 0.1289, 0.0449, 0.0625, 0.1308, 0.0585, 0.0742, 0.1269, 0.0585, 0.0722, 0.1328, 0.0566, 0.0742, 0.1347, 0.0585, 0.3691, 0.7578, 0.5429, 0.4355, 0.7695, 0.5546, 0.4414, 0.7578, 0.5468, 0.4179, 0.7265, 0.5273, 0.3945, 0.6992, 0.5234, 0.4316])),
+                ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
             ])
         # fmt: on
@@ -667,29 +672,35 @@ def test_serialization(self):
         output = pipe(**inputs)[0].flatten()
 
         with tempfile.TemporaryDirectory() as tmp_dir:
-            pipe.save_pretrained(tmp_dir, safe_serialization=False)
-            del pipe
+            pipe.transformer.save_pretrained(tmp_dir, safe_serialization=False)
+            pipe.remove_all_hooks()
+            del pipe.transformer
             gc.collect()
             torch.cuda.empty_cache()
             torch.cuda.synchronize()
-            loaded_pipe = FluxPipeline.from_pretrained(tmp_dir, use_safetensors=False)
-            loaded_pipe.enable_model_cpu_offload()
+            transformer = FluxTransformer2DModel.from_pretrained(
+                tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
+            )
+            pipe.transformer = transformer
+            pipe.enable_model_cpu_offload()
 
-        weight = loaded_pipe.transformer.x_embedder.weight
+        weight = transformer.x_embedder.weight
         self.assertTrue(isinstance(weight, AffineQuantizedTensor))
 
-        loaded_output = loaded_pipe(**inputs)[0].flatten()
+        loaded_output = pipe(**inputs)[0].flatten()
         self.assertTrue(np.allclose(output, loaded_output, atol=1e-3, rtol=1e-3))
 
     def test_memory_footprint_int4wo(self):
         # The original checkpoints are in bf16 and about 24 GB
         expected_memory_in_gb = 6.0
         quantization_config = TorchAoConfig("int4wo")
+        cache_dir = None
         transformer = FluxTransformer2DModel.from_pretrained(
             "black-forest-labs/FLUX.1-dev",
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
+            cache_dir=cache_dir,
         )
         int4wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
         self.assertTrue(int4wo_memory_in_gb < expected_memory_in_gb)
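
The 6.0 GB budget above, like the 12.0 GB budget in the next hunk, follows from bit arithmetic against the roughly 24 GB bf16 checkpoint noted in the test comment; a back-of-envelope check of where the upper bounds come from (quantization scale overhead and any layers left unquantized are ignored here):

    bf16_checkpoint_gb = 24.0                      # per the comment in the tests
    int4_estimate = bf16_checkpoint_gb * 4 / 16    # 6.0 GB: int4 weights are 4x smaller than bf16
    int8_estimate = bf16_checkpoint_gb * 8 / 16    # 12.0 GB: int8 weights are 2x smaller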
@@ -698,11 +709,13 @@ def test_memory_footprint_int8wo(self):
         # The original checkpoints are in bf16 and about 24 GB
         expected_memory_in_gb = 12.0
         quantization_config = TorchAoConfig("int8wo")
+        cache_dir = None
         transformer = FluxTransformer2DModel.from_pretrained(
             "black-forest-labs/FLUX.1-dev",
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
+            cache_dir=cache_dir,
         )
         int8wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
         self.assertTrue(int8wo_memory_in_gb < expected_memory_in_gb)