Commit a10f19c

update
1 parent d0b718a commit a10f19c

2 files changed: 36 additions & 9 deletions


src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 1 deletion
@@ -720,7 +720,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         if hf_quantizer is not None:
             if device_map is not None:
                 raise NotImplementedError(
-                    "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future."
+                    "Currently, providing `device_map` is not supported for quantized models. Support for providing `device_map` as an input will be added in the future."
                 )

             hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
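In effect, any explicit `device_map` passed alongside a quantization config now fails fast with the updated message. A minimal sketch of that behavior, assuming a diffusers build that includes this torchao integration (the checkpoint and quant type below are illustrative only):

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

try:
    # Any non-None device_map combined with a quantization config
    # should trip the NotImplementedError guard shown above.
    FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        quantization_config=TorchAoConfig("int8wo"),
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
except NotImplementedError as err:
    print(err)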

tests/quantization/torchao/test_torchao.py

Lines changed: 35 additions & 8 deletions
@@ -629,12 +629,8 @@ def _test_quant_type(self, quantization_config, expected_slice):
         output = pipe(**inputs)[0].flatten()
         output_slice = np.concatenate((output[:16], output[-16:]))

-        for weight in [
-            pipe.transformer.x_embedder.weight,
-            pipe.transformer.transformer_blocks[0].ff.net[2].weight,
-            pipe.transformer.transformer_blocks[-1].ff.net[2].weight,
-        ]:
-            self.assertTrue(isinstance(weight, AffineQuantizedTensor))
+        weight = pipe.transformer.x_embedder.weight
+        self.assertTrue(isinstance(weight, AffineQuantizedTensor))
         self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))

     def test_quantization(self):
@@ -643,7 +639,7 @@ def test_quantization(self):
             ("int8wo", np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])),
             ("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])),
         ]
-
+
         if TorchAoConfig._is_cuda_capability_atleast_8_9():
             QUANTIZATION_TYPES_TO_TEST.extend([
                 ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
@@ -672,10 +668,41 @@ def test_serialization(self):

         with tempfile.TemporaryDirectory() as tmp_dir:
             pipe.save_pretrained(tmp_dir, safe_serialization=False)
-            loaded_pipe = FluxPipeline.from_pretrained(tmp_dir, use_safetensors=False).to(torch_device)
+            del pipe
+            gc.collect()
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+            loaded_pipe = FluxPipeline.from_pretrained(tmp_dir, use_safetensors=False)
+            loaded_pipe.enable_model_cpu_offload()

         weight = loaded_pipe.transformer.x_embedder.weight
         self.assertTrue(isinstance(weight, AffineQuantizedTensor))

         loaded_output = loaded_pipe(**inputs)[0].flatten()
         self.assertTrue(np.allclose(output, loaded_output, atol=1e-3, rtol=1e-3))
+
+    def test_memory_footprint_int4wo(self):
+        # The original checkpoints are in bf16 and about 24 GB
+        expected_memory_in_gb = 6.0
+        quantization_config = TorchAoConfig("int4wo")
+        transformer = FluxTransformer2DModel.from_pretrained(
+            "black-forest-labs/FLUX.1-dev",
+            subfolder="transformer",
+            quantization_config=quantization_config,
+            torch_dtype=torch.bfloat16,
+        )
+        int4wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
+        self.assertTrue(int4wo_memory_in_gb < expected_memory_in_gb)
+
+    def test_memory_footprint_int8wo(self):
+        # The original checkpoints are in bf16 and about 24 GB
+        expected_memory_in_gb = 12.0
+        quantization_config = TorchAoConfig("int8wo")
+        transformer = FluxTransformer2DModel.from_pretrained(
+            "black-forest-labs/FLUX.1-dev",
+            subfolder="transformer",
+            quantization_config=quantization_config,
+            torch_dtype=torch.bfloat16,
+        )
+        int8wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
+        self.assertTrue(int8wo_memory_in_gb < expected_memory_in_gb)
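The two new tests assert that weight-only quantization actually shrinks the roughly 24 GB bf16 FLUX transformer. A minimal standalone sketch of the same measurement, assuming torchao is installed and that the test imports get_model_size_in_bytes from torchao.utils:

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig
from torchao.utils import get_model_size_in_bytes

# int4 weight-only quantization; the test expects this to come in
# under 6 GB, versus roughly 24 GB for the original bf16 checkpoint.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=TorchAoConfig("int4wo"),
    torch_dtype=torch.bfloat16,
)
print(get_model_size_in_bytes(transformer) / 1024**3, "GB")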
