Commit 41f2bf4

Add option to real quantize the model (#473)
Signed-off-by: ajrasane <[email protected]>
1 parent 7b6a15a

1 file changed: 18 additions, 0 deletions

examples/diffusers/quantization/quantize.py

@@ -168,13 +168,16 @@ class QuantizationConfig:
     alpha: float = 1.0  # SmoothQuant alpha
     lowrank: int = 32  # SVDQuant lowrank
     quantize_mha: bool = False
+    compress: bool = False
 
     def validate(self) -> None:
         """Validate configuration consistency."""
         if self.format == QuantFormat.FP8 and self.collect_method != CollectMethod.DEFAULT:
             raise NotImplementedError("Only 'default' collect method is implemented for FP8.")
         if self.quantize_mha and self.format == QuantFormat.INT8:
             raise ValueError("MHA quantization is only supported for FP8, not INT8.")
+        if self.compress and self.format == QuantFormat.INT8:
+            raise ValueError("Compression is only supported for FP8 and FP4, not INT8.")
 
 
 @dataclass
@@ -766,6 +769,9 @@ def create_argument_parser() -> argparse.ArgumentParser:
         # FP8 quantization with ONNX export
         %(prog)s --model sd3-medium --format fp8 --onnx-dir ./onnx_models/
 
+        # FP8 quantization with weight compression (reduces memory footprint)
+        %(prog)s --model flux-dev --format fp8 --compress
+
         # Quantize LTX-Video model with full multi-stage pipeline
         %(prog)s --model ltx-video-dev --format fp8 --batch-size 1 --calib-size 32
 
@@ -835,6 +841,11 @@ def create_argument_parser() -> argparse.ArgumentParser:
     quant_group.add_argument(
         "--quantize-mha", action="store_true", help="Quantizing MHA into FP8 if its True"
     )
+    quant_group.add_argument(
+        "--compress",
+        action="store_true",
+        help="Compress quantized weights to reduce memory footprint (FP8/FP4 only)",
+    )
 
     calib_group = parser.add_argument_group("Calibration Configuration")
     calib_group.add_argument("--batch-size", type=int, default=2, help="Batch size for calibration")
@@ -894,6 +905,7 @@ def main() -> None:
         alpha=args.alpha,
         lowrank=args.lowrank,
         quantize_mha=args.quantize_mha,
+        compress=args.compress,
     )
 
     calib_config = CalibrationConfig(
@@ -940,6 +952,12 @@ def forward_loop(mod):
 
     quantizer.quantize_model(backbone, backbone_quant_config, forward_loop)
 
+    # Compress model weights if requested (only for FP8/FP4)
+    if quant_config.compress:
+        logger.info("Compressing model weights to reduce memory footprint...")
+        mtq.compress(backbone)
+        logger.info("Model compression completed")
+
     export_manager.save_checkpoint(backbone)
     export_manager.export_onnx(
         pipe,
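
Taken together, the diff threads a single --compress switch through the QuantizationConfig dataclass, its validate() check, the argument parser, and the post-calibration step, where mtq.compress(backbone) turns the calibrated weights into their real low-precision form before export. The sketch below shows that flow in isolation; only mtq.compress(backbone) and the FP8/FP4-only restriction come from the diff, while quantize_and_maybe_compress, the quant_cfg argument, and the mtq.quantize call are illustrative assumptions based on the usual modelopt.torch.quantization API.

# Minimal sketch, not the script itself: `backbone` is the diffusion backbone
# nn.Module, `quant_cfg` a modelopt quantization config, and `forward_loop`
# runs calibration batches through the model (all assumed names).
import modelopt.torch.quantization as mtq


def quantize_and_maybe_compress(backbone, quant_cfg, forward_loop, compress=False):
    # Insert quantizers and calibrate their scales with the forward loop
    # (assumed to mirror what quantizer.quantize_model() does in the script).
    mtq.quantize(backbone, quant_cfg, forward_loop=forward_loop)
    if compress:
        # From the diff: replace the calibrated high-precision weights with their
        # compressed FP8/FP4 representation, shrinking the memory footprint.
        mtq.compress(backbone)
    return backbone

From the command line, the same path is exercised by the example the commit adds to the epilog, %(prog)s --model flux-dev --format fp8 --compress; combining --compress with --format int8 now fails fast in validate() with the "Compression is only supported for FP8 and FP4" error.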
