Skip to content

Commit 87bb2fe

Browse files
committed
update
1 parent c6651f9 commit 87bb2fe

File tree

2 files changed: +23 additions, −1 deletion

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def validate_environment(self, *args, **kwargs):
125125
def update_torch_dtype(self, torch_dtype):
126126
quant_type = self.quantization_config.quant_type
127127

128-
if quant_type.startswith("int"):
128+
if quant_type.startswith("int") or quant_type.startswith("uint"):
129129
if torch_dtype is not None and torch_dtype != torch.bfloat16:
130130
logger.warning(
131131
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "

tests/quantization/torchao/test_torchao.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,3 +654,25 @@ def test_quantization(self):
654654
gc.collect()
655655
torch.cuda.empty_cache()
656656
torch.cuda.synchronize()
657+
658+
def test_serialization(self):
659+
quantization_config = TorchAoConfig("int8wo")
660+
components = self.get_dummy_components(quantization_config)
661+
pipe = FluxPipeline(**components)
662+
pipe.enable_model_cpu_offload()
663+
664+
weight = pipe.transformer.x_embedder.weight
665+
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
666+
667+
inputs = self.get_dummy_inputs(torch_device)
668+
output = pipe(**inputs)[0].flatten()
669+
670+
with tempfile.TemporaryDirectory() as tmp_dir:
671+
pipe.save_pretrained(tmp_dir, safe_serialization=False)
672+
loaded_pipe = FluxPipeline.from_pretrained(tmp_dir, use_safetensors=False).to(torch_device)
673+
674+
weight = loaded_pipe.transformer.x_embedder.weight
675+
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
676+
677+
loaded_output = loaded_pipe(**inputs)[0].flatten()
678+
self.assertTrue(np.allclose(output, loaded_output, atol=1e-3, rtol=1e-3))

0 commit comments

Comments (0)