@@ -443,21 +443,29 @@ def test_memory_footprint(self):
         transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
         transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
 
-        self.assertTrue(
-            isinstance(transformer_int4wo.transformer_blocks[0].ff.net[2].weight, AffineQuantizedTensor)
-        )
-        self.assertTrue(
-            isinstance(transformer_int4wo_gs32.transformer_blocks[0].ff.net[2].weight, AffineQuantizedTensor)
-        )
-        self.assertTrue(
-            isinstance(transformer_int8wo.transformer_blocks[0].ff.net[2].weight, AffineQuantizedTensor)
-        )
+        # Will not quantize all the layers by default because the model weight shapes are not divisible by group_size=64
+        for block in transformer_int4wo.transformer_blocks:
+            self.assertTrue(isinstance(block.ff.net[2].weight, AffineQuantizedTensor))
+            self.assertTrue(isinstance(block.ff_context.net[2].weight, AffineQuantizedTensor))
+
+        # Will quantize all the linear layers except x_embedder
+        for name, module in transformer_int4wo_gs32.named_modules():
+            if isinstance(module, nn.Linear) and name != "x_embedder":
+                self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
+
+        # Will quantize all the linear layers
+        for module in transformer_int8wo.modules():
+            if isinstance(module, nn.Linear):
+                self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
 
         total_int4wo = get_model_size_in_bytes(transformer_int4wo)
         total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32)
         total_int8wo = get_model_size_in_bytes(transformer_int8wo)
         total_bf16 = get_model_size_in_bytes(transformer_bf16)
 
+        # TODO: refactor to align with other quantization tests
         # Latter has smaller group size, so more groups -> more scales and zero points
         self.assertTrue(total_int4wo < total_int4wo_gs32)
         # int8 quantizes more layers compared to int4 with default group size
@@ -735,3 +743,60 @@ def test_memory_footprint_int8wo(self):
         )
         int8wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
         self.assertTrue(int8wo_memory_in_gb < expected_memory_in_gb)
+
+
+@require_torch
+@require_torch_gpu
+@require_torchao_version_greater_or_equal("0.7.0")
+@slow
+@nightly
+class SlowTorchAoPreserializedModelTests(unittest.TestCase):
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_dummy_inputs(self, device: torch.device, seed: int = 0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator().manual_seed(seed)
+
+        inputs = {
+            "prompt": "an astronaut riding a horse in space",
+            "height": 512,
+            "width": 512,
+            "num_inference_steps": 20,
+            "output_type": "np",
+            "generator": generator,
+        }
+
+        return inputs
+
+    def test_transformer_int8wo(self):
+        # fmt: off
+        expected_slice = np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])
+        # fmt: on
+
+        # This is just for convenience, so that we can modify it in one place for custom environments and local testing
+        cache_dir = None
+        transformer = FluxTransformer2DModel.from_pretrained(
+            "hf-internal-testing/FLUX.1-Dev-TorchAO-int8wo-transformer",
+            torch_dtype=torch.bfloat16,
+            use_safetensors=False,
+            cache_dir=cache_dir,
+        )
+        pipe = FluxPipeline.from_pretrained(
+            "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16, cache_dir=cache_dir
+        )
+        pipe.enable_model_cpu_offload()
+
+        # Verify that all linear layer weights are quantized
+        for name, module in pipe.transformer.named_modules():
+            if isinstance(module, nn.Linear):
+                self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
+
+        # Verify outputs match expected slice
+        inputs = self.get_dummy_inputs(torch_device)
+        output = pipe(**inputs)[0].flatten()
+        output_slice = np.concatenate((output[:16], output[-16:]))
+        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))