
Commit 497c56d

Commit message: update
1 parent bb2fb98

2 files changed: 45 additions (+), 42 deletions (-)


src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 5 additions & 0 deletions
@@ -93,6 +93,11 @@ def validate_environment(self, *args, **kwargs):
             raise ImportError(
                 "Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
             )
+        torchao_version = version.parse(importlib.metadata.version("torchao"))
+        if torchao_version < version.parse("0.7.0"):
+            raise RuntimeError(
+                f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
+            )

         self.offload = False

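The added guard is the usual importlib.metadata / packaging pattern for enforcing a minimum dependency version at load time. A minimal standalone sketch of the same check, assuming torchao is already installed; the helper name `require_min_torchao` is illustrative and not part of this commit:

import importlib.metadata

from packaging import version


def require_min_torchao(minimum: str = "0.7.0") -> None:
    # Compare the installed torchao version against the required floor (0.7.0 in this commit).
    installed = version.parse(importlib.metadata.version("torchao"))
    if installed < version.parse(minimum):
        raise RuntimeError(
            f"The minimum required version of `torchao` is {minimum}, but the current version is "
            f"{installed}. Please upgrade with `pip install -U torchao`."
        )
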
tests/quantization/torchao/test_torchao.py

Lines changed: 40 additions & 42 deletions
@@ -75,6 +75,7 @@ def forward(self, input, *args, **kwargs):
 if is_torchao_available():
     from torchao.dtypes import AffineQuantizedTensor
     from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
+    from torchao.utils import get_model_size_in_bytes


 @require_torch
@@ -138,11 +139,13 @@ def get_dummy_components(self, quantization_config: TorchAoConfig):
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
         )
-        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
-        text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
+        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
+        text_encoder_2 = T5EncoderModel.from_pretrained(
+            model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
+        )
         tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
         tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
-        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
+        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
         scheduler = FlowMatchEulerDiscreteScheduler()

         return {
@@ -211,7 +214,7 @@ def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
     def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float]):
         components = self.get_dummy_components(quantization_config)
         pipe = FluxPipeline(**components)
-        pipe.to(device=torch_device, dtype=torch.bfloat16)
+        pipe.to(device=torch_device)

         inputs = self.get_dummy_inputs(torch_device)
         output = pipe(**inputs)[0]
@@ -315,21 +318,33 @@ def test_offload(self):

     def test_modules_to_not_convert(self):
         quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
-        quantized_model = FluxTransformer2DModel.from_pretrained(
+        quantized_model_with_not_convert = FluxTransformer2DModel.from_pretrained(
             "hf-internal-testing/tiny-flux-pipe",
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
         )

-        unquantized_layer = quantized_model.transformer_blocks[0].ff.net[2]
+        unquantized_layer = quantized_model_with_not_convert.transformer_blocks[0].ff.net[2]
         self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear))
         self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor))
         self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16)

-        quantized_layer = quantized_model.proj_out
+        quantized_layer = quantized_model_with_not_convert.proj_out
         self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor))
-        self.assertEqual(quantized_layer.weight.layout_tensor.data.dtype, torch.int8)
+
+        quantization_config = TorchAoConfig("int8_weight_only")
+        quantized_model = FluxTransformer2DModel.from_pretrained(
+            "hf-internal-testing/tiny-flux-pipe",
+            subfolder="transformer",
+            quantization_config=quantization_config,
+            torch_dtype=torch.bfloat16,
+        )
+
+        size_quantized_with_not_convert = get_model_size_in_bytes(quantized_model_with_not_convert)
+        size_quantized = get_model_size_in_bytes(quantized_model)
+
+        self.assertTrue(size_quantized < size_quantized_with_not_convert)

     def test_training(self):
         quantization_config = TorchAoConfig("int8_weight_only")
@@ -380,23 +395,6 @@ def test_torch_compile(self):
         # Note: Seems to require higher tolerance
         self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))

-    @staticmethod
-    def _get_memory_footprint(module):
-        quantized_param_memory = 0.0
-        unquantized_param_memory = 0.0
-
-        for param in module.parameters():
-            if param.__class__.__name__ == "AffineQuantizedTensor":
-                data, scale, zero_point = param.layout_tensor.get_plain()
-                quantized_param_memory += data.numel() + data.element_size()
-                quantized_param_memory += scale.numel() + scale.element_size()
-                quantized_param_memory += zero_point.numel() + zero_point.element_size()
-            else:
-                unquantized_param_memory += param.data.numel() * param.data.element_size()
-
-        total_memory = quantized_param_memory + unquantized_param_memory
-        return total_memory, quantized_param_memory, unquantized_param_memory
-
     def test_memory_footprint(self):
         r"""
         A simple test to check if the model conversion has been done correctly by checking on the
@@ -407,20 +405,18 @@ def test_memory_footprint(self):
         transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"]
         transformer_bf16 = self.get_dummy_components(None)["transformer"]

-        total_int4wo, quantized_int4wo, unquantized_int4wo = self._get_memory_footprint(transformer_int4wo)
-        total_int4wo_gs32, quantized_int4wo_gs32, unquantized_int4wo_gs32 = self._get_memory_footprint(
-            transformer_int4wo_gs32
-        )
-        total_int8wo, quantized_int8wo, unquantized_int8wo = self._get_memory_footprint(transformer_int8wo)
-        total_bf16, quantized_bf16, unquantized_bf16 = self._get_memory_footprint(transformer_bf16)
-
-        self.assertTrue(quantized_bf16 == 0 and total_bf16 == unquantized_bf16)
-        # int4wo_gs32 has smaller group size, so more groups -> more scales and zero points
-        self.assertTrue(total_int8wo < total_bf16 < total_int4wo_gs32)
-        # int4 with default group size quantized very few linear layers compared to a smaller group size of 32
-        self.assertTrue(quantized_int4wo < quantized_int4wo_gs32 and unquantized_int4wo > unquantized_int4wo_gs32)
+        total_int4wo = get_model_size_in_bytes(transformer_int4wo)
+        total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32)
+        total_int8wo = get_model_size_in_bytes(transformer_int8wo)
+        total_bf16 = get_model_size_in_bytes(transformer_bf16)
+
+        # Latter has smaller group size, so more groups -> more scales and zero points
+        self.assertTrue(total_int4wo < total_int4wo_gs32)
         # int8 quantizes more layers compare to int4 with default group size
-        self.assertTrue(quantized_int8wo < quantized_int4wo)
+        self.assertTrue(total_int8wo < total_int4wo)
+        # int4wo does not quantize too many layers because of default group size, but for the layers it does
+        # there is additional overhead of scales and zero points
+        self.assertTrue(total_bf16 < total_int4wo)

     def test_wrong_config(self):
         with self.assertRaises(ValueError):
@@ -555,11 +551,13 @@ def get_dummy_components(self, quantization_config: TorchAoConfig):
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
         )
-        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
-        text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
+        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
+        text_encoder_2 = T5EncoderModel.from_pretrained(
+            model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
+        )
         tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
         tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
-        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
+        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
         scheduler = FlowMatchEulerDiscreteScheduler()

         return {
@@ -591,7 +589,7 @@ def get_dummy_inputs(self, device: torch.device, seed: int = 0):

     def _test_quant_type(self, quantization_config, expected_slice):
         components = self.get_dummy_components(quantization_config)
-        pipe = FluxPipeline(**components).to(dtype=torch.bfloat16)
+        pipe = FluxPipeline(**components)
         pipe.enable_model_cpu_offload()

         inputs = self.get_dummy_inputs(torch_device)
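
With `get_model_size_in_bytes` from `torchao.utils` replacing the removed `_get_memory_footprint` helper, each size comparison in the tests collapses to a single call per model. A minimal sketch of the same measurement outside the test suite, assuming the tiny Flux checkpoint used above and a diffusers build that exposes `FluxTransformer2DModel` and `TorchAoConfig` at the top level:

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig
from torchao.utils import get_model_size_in_bytes

model_id = "hf-internal-testing/tiny-flux-pipe"

# int8 weight-only quantized copy of the transformer.
quantized = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=TorchAoConfig("int8_weight_only"),
    torch_dtype=torch.bfloat16,
)

# Unquantized bf16 copy of the same transformer for comparison.
unquantized = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)

# Weight-only int8 storage is expected to be smaller than bf16 storage.
assert get_model_size_in_bytes(quantized) < get_model_size_in_bytes(unquantized)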
