Commit d1b6405

update tests
Parent: 651666d

1 file changed, +27 -14 lines:

tests/quantization/torchao/test_torchao.py

@@ -577,20 +577,25 @@ def tearDown(self):
         torch.cuda.empty_cache()
 
     def get_dummy_components(self, quantization_config: TorchAoConfig):
+        # This is just for convenience, so that we can modify it in one place for custom environments and local testing
+        cache_dir = None
         model_id = "black-forest-labs/FLUX.1-dev"
         transformer = FluxTransformer2DModel.from_pretrained(
             model_id,
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
+            cache_dir=cache_dir,
+        )
+        text_encoder = CLIPTextModel.from_pretrained(
+            model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, cache_dir=cache_dir
         )
-        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
         text_encoder_2 = T5EncoderModel.from_pretrained(
-            model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
+            model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16, cache_dir=cache_dir
         )
-        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
-        tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
-        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
+        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir)
+        tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir)
+        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16, cache_dir=cache_dir)
         scheduler = FlowMatchEulerDiscreteScheduler()
 
         return {
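Note on the new cache_dir hook: leaving it at None falls through to the default Hugging Face cache, while pointing it at a local path redirects every from_pretrained call in get_dummy_components at once. A minimal sketch of how it could be wired up for local runs; the DIFFUSERS_TEST_CACHE_DIR environment variable is hypothetical, not part of this commit:

import os

# None keeps the default HF cache; the (hypothetical) env var overrides it in one place.
cache_dir = os.environ.get("DIFFUSERS_TEST_CACHE_DIR", None)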
@@ -624,9 +629,9 @@ def _test_quant_type(self, quantization_config, expected_slice):
         components = self.get_dummy_components(quantization_config)
         pipe = FluxPipeline(**components)
         pipe.enable_model_cpu_offload()
-
+
         weight = pipe.transformer.transformer_blocks[0].ff.net[2].weight
-        self.assertTrue(isinstance(weight, AffineQuantizedTensor))
+        self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
 
         inputs = self.get_dummy_inputs(torch_device)
         output = pipe(**inputs)[0].flatten()
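Note on the widened isinstance check: torchao wraps weights differently depending on the scheme, so a test that iterates over many quant types has to accept both subclasses. Weight-only schemes (e.g. int8wo) store the weight as an AffineQuantizedTensor, while dynamic-activation schemes (e.g. int8dq) wrap it as a LinearActivationQuantizedTensor. A minimal standalone sketch, assuming a recent torchao; import paths can vary across versions:

import torch
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization import quantize_, int8_weight_only, int8_dynamic_activation_int8_weight

# Weight-only: the weight itself becomes an AffineQuantizedTensor.
wo = torch.nn.Linear(64, 64, bias=False)
quantize_(wo, int8_weight_only())
print(type(wo.weight).__name__)  # AffineQuantizedTensor

# Dynamic activation quantization adds one more wrapper around the quantized weight.
dq = torch.nn.Linear(64, 64, bias=False)
quantize_(dq, int8_dynamic_activation_int8_weight())
print(type(dq.weight).__name__)  # LinearActivationQuantizedTensor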
@@ -643,7 +648,7 @@ def test_quantization(self):
         if TorchAoConfig._is_cuda_capability_atleast_8_9():
             QUANTIZATION_TYPES_TO_TEST.extend([
                 ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
-                ("fp5_e3m1", np.array([0.0527, 0.0742, 0.1289, 0.0449, 0.0625, 0.1308, 0.0585, 0.0742, 0.1269, 0.0585, 0.0722, 0.1328, 0.0566, 0.0742, 0.1347, 0.0585, 0.3691, 0.7578, 0.5429, 0.4355, 0.7695, 0.5546, 0.4414, 0.7578, 0.5468, 0.4179, 0.7265, 0.5273, 0.3945, 0.6992, 0.5234, 0.4316])),
+                ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
             ])
         # fmt: on
 
@@ -667,29 +672,35 @@ def test_serialization(self):
         output = pipe(**inputs)[0].flatten()
 
         with tempfile.TemporaryDirectory() as tmp_dir:
-            pipe.save_pretrained(tmp_dir, safe_serialization=False)
-            del pipe
+            pipe.transformer.save_pretrained(tmp_dir, safe_serialization=False)
+            pipe.remove_all_hooks()
+            del pipe.transformer
             gc.collect()
             torch.cuda.empty_cache()
             torch.cuda.synchronize()
-            loaded_pipe = FluxPipeline.from_pretrained(tmp_dir, use_safetensors=False)
-            loaded_pipe.enable_model_cpu_offload()
+            transformer = FluxTransformer2DModel.from_pretrained(
+                tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
+            )
+            pipe.transformer = transformer
+            pipe.enable_model_cpu_offload()
 
-        weight = loaded_pipe.transformer.x_embedder.weight
+        weight = transformer.x_embedder.weight
         self.assertTrue(isinstance(weight, AffineQuantizedTensor))
 
-        loaded_output = loaded_pipe(**inputs)[0].flatten()
+        loaded_output = pipe(**inputs)[0].flatten()
         self.assertTrue(np.allclose(output, loaded_output, atol=1e-3, rtol=1e-3))
 
     def test_memory_footprint_int4wo(self):
         # The original checkpoints are in bf16 and about 24 GB
         expected_memory_in_gb = 6.0
         quantization_config = TorchAoConfig("int4wo")
+        cache_dir = None
         transformer = FluxTransformer2DModel.from_pretrained(
             "black-forest-labs/FLUX.1-dev",
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
+            cache_dir=cache_dir,
         )
         int4wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
         self.assertTrue(int4wo_memory_in_gb < expected_memory_in_gb)
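Note on the serialization change: the test now saves and reloads only the quantized transformer instead of round-tripping the whole pipeline, then reattaches it to the still-live pipe. The safe_serialization=False / use_safetensors=False pair is used because torchao weights are tensor subclasses, which go through torch's pickle-based format rather than safetensors. A minimal sketch of the same roundtrip outside the test harness, assuming quantized_transformer was loaded with a TorchAoConfig and "./flux-transformer-int8wo" is a scratch directory (both hypothetical here):

import torch
from diffusers import FluxTransformer2DModel

# Pickle-based save: quantized tensor subclasses are not saved as safetensors.
quantized_transformer.save_pretrained("./flux-transformer-int8wo", safe_serialization=False)
reloaded = FluxTransformer2DModel.from_pretrained(
    "./flux-transformer-int8wo", torch_dtype=torch.bfloat16, use_safetensors=False
)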
@@ -698,11 +709,13 @@ def test_memory_footprint_int8wo(self):
         # The original checkpoints are in bf16 and about 24 GB
         expected_memory_in_gb = 12.0
         quantization_config = TorchAoConfig("int8wo")
+        cache_dir = None
         transformer = FluxTransformer2DModel.from_pretrained(
             "black-forest-labs/FLUX.1-dev",
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
+            cache_dir=cache_dir,
         )
         int8wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
         self.assertTrue(int8wo_memory_in_gb < expected_memory_in_gb)
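Note on the memory bounds: the bf16 checkpoint is about 24 GB, so int8 weight-only should land near half of that and int4 near a quarter, which is where the 12.0 and 6.0 GB thresholds come from. A sketch of the same check, assuming transformer was loaded with TorchAoConfig("int8wo") as in the hunk above:

from torchao.utils import get_model_size_in_bytes

# ~24 GB bf16 -> ~12 GB at int8, ~6 GB at int4 (weight-only).
size_gb = get_model_size_in_bytes(transformer) / 1024**3
assert size_gb < 12.0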
