@@ -168,13 +168,16 @@ class QuantizationConfig:
     alpha: float = 1.0  # SmoothQuant alpha
     lowrank: int = 32  # SVDQuant lowrank
     quantize_mha: bool = False
+    compress: bool = False

     def validate(self) -> None:
         """Validate configuration consistency."""
         if self.format == QuantFormat.FP8 and self.collect_method != CollectMethod.DEFAULT:
             raise NotImplementedError("Only 'default' collect method is implemented for FP8.")
         if self.quantize_mha and self.format == QuantFormat.INT8:
             raise ValueError("MHA quantization is only supported for FP8, not INT8.")
+        if self.compress and self.format == QuantFormat.INT8:
+            raise ValueError("Compression is only supported for FP8 and FP4, not INT8.")


 @dataclass
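The new guard mirrors the existing quantize_mha check. A minimal sketch of what it rejects (illustrative only: the module name and the assumption that the remaining dataclass fields have defaults are not part of the diff):

    # Hypothetical usage: INT8 together with compress=True should raise.
    from quantize import QuantFormat, QuantizationConfig  # module name assumed

    cfg = QuantizationConfig(format=QuantFormat.INT8, compress=True)
    try:
        cfg.validate()
    except ValueError as err:
        print(err)  # "Compression is only supported for FP8 and FP4, not INT8."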
@@ -766,6 +769,9 @@ def create_argument_parser() -> argparse.ArgumentParser:
     # FP8 quantization with ONNX export
     %(prog)s --model sd3-medium --format fp8 --onnx-dir ./onnx_models/

+    # FP8 quantization with weight compression (reduces memory footprint)
+    %(prog)s --model flux-dev --format fp8 --compress
+
     # Quantize LTX-Video model with full multi-stage pipeline
     %(prog)s --model ltx-video-dev --format fp8 --batch-size 1 --calib-size 32

@@ -835,6 +841,11 @@ def create_argument_parser() -> argparse.ArgumentParser:
     quant_group.add_argument(
         "--quantize-mha", action="store_true", help="Quantizing MHA into FP8 if its True"
     )
+    quant_group.add_argument(
+        "--compress",
+        action="store_true",
+        help="Compress quantized weights to reduce memory footprint (FP8/FP4 only)",
+    )

     calib_group = parser.add_argument_group("Calibration Configuration")
     calib_group.add_argument("--batch-size", type=int, default=2, help="Batch size for calibration")
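Because the flag uses action="store_true", args.compress defaults to False and only flips to True when --compress is passed. A quick, illustrative check (it assumes --model and --format accept the values shown in the epilog examples):

    parser = create_argument_parser()

    args = parser.parse_args(["--model", "flux-dev", "--format", "fp8", "--compress"])
    assert args.compress is True

    args = parser.parse_args(["--model", "flux-dev", "--format", "fp8"])
    assert args.compress is False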
@@ -894,6 +905,7 @@ def main() -> None:
         alpha=args.alpha,
         lowrank=args.lowrank,
         quantize_mha=args.quantize_mha,
+        compress=args.compress,
     )

     calib_config = CalibrationConfig(
@@ -940,6 +952,12 @@ def forward_loop(mod):

     quantizer.quantize_model(backbone, backbone_quant_config, forward_loop)

+    # Compress model weights if requested (only for FP8/FP4)
+    if quant_config.compress:
+        logger.info("Compressing model weights to reduce memory footprint...")
+        mtq.compress(backbone)
+        logger.info("Model compression completed")
+
     export_manager.save_checkpoint(backbone)
     export_manager.export_onnx(
         pipe,
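To confirm the compression step actually shrinks the backbone, a rough before/after footprint check could be wrapped around the mtq.compress() call. This is a sketch, not part of the change; it only counts tensors in the state dict and ignores any non-tensor state:

    def state_dict_bytes(model):
        # Rough footprint: total bytes of every tensor in the state dict.
        return sum(
            t.numel() * t.element_size()
            for t in model.state_dict().values()
            if hasattr(t, "numel")
        )

    before = state_dict_bytes(backbone)
    mtq.compress(backbone)
    after = state_dict_bytes(backbone)
    logger.info(f"Backbone footprint: {before / 2**30:.2f} GiB -> {after / 2**30:.2f} GiB")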