diff --git a/optimum/commands/export/executorch.py b/optimum/commands/export/executorch.py index 09a907f..80d6a4c 100644 --- a/optimum/commands/export/executorch.py +++ b/optimum/commands/export/executorch.py @@ -76,12 +76,14 @@ def parse_args_executorch(parser): required_group.add_argument( "--qlinear", type=str, - choices=["8da4w", "4w", "8w"], + choices=["8da4w", "4w", "8w", "8da8w", "8da4w,8da8w"], required=False, help=( "Quantization config for decoder linear layers.\n\n" "Options:\n" " 8da4w - 8-bit dynamic activation, 4-bit weight\n" + " 8da8w - 8-bit dynamic activation, 8-bit weight\n" + " 8da4w,8da8w - 8-bit dynamic activation, 4-bit weight and 8-bit weight\n" " 4w - 4-bit weight only\n" " 8w - 8-bit weight only" ), @@ -92,12 +94,14 @@ def parse_args_executorch(parser): required_group.add_argument( "--qlinear_encoder", type=str, - choices=["8da4w", "4w", "8w"], + choices=["8da4w", "4w", "8w", "8da8w", "8da4w,8da8w"], required=False, help=( "Quantization config for linear layers.\n\n" "Options:\n" " 8da4w - 8-bit dynamic activation, 4-bit weight\n" + " 8da8w - 8-bit dynamic activation, 8-bit weight\n" + " 8da4w,8da8w - 8-bit dynamic activation, 4-bit weight; fallback on 8-bit dynamic activation, 8-bit weight per-channel where group size doesn't divide block size cleanly \n" " 4w - 4-bit weight only\n" " 8w - 8-bit weight only" ), @@ -148,15 +152,15 @@ def run(self): if self.args.qlinear: kwargs["qlinear"] = self.args.qlinear if self.args.qlinear_group_size: - kwargs["qlinear_group_size"] = self.args.qlinear + kwargs["qlinear_group_size"] = self.args.qlinear_group_size if self.args.qlinear_encoder: - kwargs["qlinear_encoder"] = self.args.qlinear + kwargs["qlinear_encoder"] = self.args.qlinear_encoder if self.args.qlinear_encoder_group_size: - kwargs["qlinear_encoder_group_size"] = self.args.qlinear + kwargs["qlinear_encoder_group_size"] = self.args.qlinear_encoder_group_size if self.args.qembedding: kwargs["qembedding"] = self.args.qembedding if self.args.qembedding_group_size: - kwargs["qembedding_group_size"] = self.args.qembedding + kwargs["qembedding_group_size"] = self.args.qembedding_group_size if self.args.max_seq_len: kwargs["max_seq_len"] = self.args.max_seq_len diff --git a/optimum/exporters/executorch/quantization.py b/optimum/exporters/executorch/quantization.py index 395ec55..ded5178 100644 --- a/optimum/exporters/executorch/quantization.py +++ b/optimum/exporters/executorch/quantization.py @@ -40,7 +40,7 @@ def quantize_model_( if qlinear_config == "8w": assert ( qembedding_group_size == 0 - ), "8-bit embedding quantization only supports per-channel at the moment, please use qembedding_group_size = 0." + ), "8-bit embedding quantization only supports per-token at the moment, please use qembedding_group_size = 0." if qembedding_group_size == 0: embedding_weight_granularity = PerAxis(0) else: @@ -67,30 +67,87 @@ def quantize_model_( ) if qlinear_config: + + def build_linear_config(config_key: str, granularity): + if config_key == "8da4w": + return Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int4, + weight_granularity=granularity, + ) + if config_key == "4w": + return IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=granularity, + ) + if config_key == "8w": + return IntxWeightOnlyConfig( + weight_dtype=torch.int8, + granularity=granularity, + ) + if config_key == "8da8w": + return Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int8, + weight_granularity=PerAxis(0), + ) + raise ValueError(f"Unsupported linear quantization config '{config_key}'.") + + qlinear_configs = [cfg.strip() for cfg in qlinear_config.split(",")] + if any(cfg == "" for cfg in qlinear_configs): + raise ValueError("Linear quantization config entries must be non-empty.") + if len(qlinear_configs) > 2: + raise ValueError("Expected at most one fallback linear quantization config, got more than one comma.") + + primary_linear_config_key = qlinear_configs[0] + fallback_linear_config_key = qlinear_configs[1] if len(qlinear_configs) == 2 else None + if qlinear_group_size == 0: linear_weight_granularity = PerAxis(0) + if fallback_linear_config_key is not None: + logging.warning( + "qlinear_group_size is 0, fallback linear config will not be used as all layers will be quantized with per-axis granularity." + ) + fallback_linear_config_key = None else: - assert qlinear_group_size % 2 == 0, "Linear quantization group size must be a multiple of 2." + assert ( + qlinear_group_size % 2 == 0 + ), f"Linear quantization group size must be a multiple of 2, got {qlinear_group_size}." linear_weight_granularity = PerGroup(qlinear_group_size) logging.info("Quantizing linear layers.") - linear_config = { - "8da4w": Int8DynamicActivationIntxWeightConfig( - weight_dtype=torch.int4, - weight_granularity=linear_weight_granularity, - ), - "4w": IntxWeightOnlyConfig( - weight_dtype=torch.int4, - granularity=linear_weight_granularity, - ), - "8w": IntxWeightOnlyConfig( - weight_dtype=torch.int8, - granularity=linear_weight_granularity, - ), - }[qlinear_config] + primary_linear_config = build_linear_config(primary_linear_config_key, linear_weight_granularity) + + # First, quantize layers that are compatible with group quantization + def per_group_filter(module, fqn): + if isinstance(module, torch.nn.Linear): + # Check if hidden dimension is divisible by group size + # For Linear layers, weight shape is [out_features, in_features] + # Group quantization typically applies to the in_features dimension (dim=1) + return qlinear_group_size == 0 or (module.weight.shape[1] % qlinear_group_size == 0) + return False + quantize_( eager_model, - linear_config, + primary_linear_config, + filter_fn=per_group_filter, ) + # Then, quantize incompatible layers using the fallback per-axis config + if fallback_linear_config_key is not None: + fallback_linear_config = build_linear_config(fallback_linear_config_key, PerAxis(0)) + + def per_token_filter(module, fqn): + if isinstance(module, torch.nn.Linear): + return module.weight.shape[1] % qlinear_group_size != 0 + return False + + logging.info( + f"Applying fallback linear config '{fallback_linear_config_key}' (per-axis)" + f" to layers incompatible with group size {qlinear_group_size}." + ) + quantize_( + eager_model, + fallback_linear_config, + filter_fn=per_token_filter, + ) + unwrap_tensor_subclass(eager_model) diff --git a/optimum/exporters/executorch/tasks/multimodal_text_to_text.py b/optimum/exporters/executorch/tasks/multimodal_text_to_text.py index 4449570..d3bf007 100644 --- a/optimum/exporters/executorch/tasks/multimodal_text_to_text.py +++ b/optimum/exporters/executorch/tasks/multimodal_text_to_text.py @@ -14,6 +14,7 @@ import json +import logging import os.path import torchao @@ -201,15 +202,24 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs): qembedding_group_size = kwargs.get("qembedding_group_size", None) # Quantize decoder linear weights. + if qlinear_config: + logging.info("Quantizing decoder linears...") quantize_decoder_kwargs = { "eager_model": getattr(eager_model, decoder_name), "qlinear_config": qlinear_config, } + quantize_lm_head_kwargs = { + "eager_model": eager_model.lm_head, + "qlinear_config": qlinear_config, + } if qlinear_group_size is not None: quantize_decoder_kwargs["qlinear_group_size"] = qlinear_group_size quantize_model_(**quantize_decoder_kwargs) + quantize_model_(**quantize_lm_head_kwargs) # Quantize encoder linear weights. + if qlinear_encoder_config: + logging.info("Quantizing encoder linears...") quantize_encoder_kwargs = { "eager_model": getattr(eager_model, encoder_name), "qlinear_config": qlinear_encoder_config, @@ -218,19 +228,17 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs): quantize_encoder_kwargs["qlinear_group_size"] = qlinear_encoder_group_size quantize_model_(**quantize_encoder_kwargs) - # TODO: quantize other parts of the model, e.g. MultimodalProjector? - # Quantize decoder embeddings. + if qembedding_config: + logging.info("Quantizing decoder embeddings...") quantize_decoder_embedding_kwargs = { - "eager_model": getattr(eager_model, decoder_name), + "eager_model": eager_model, "qembedding_config": qembedding_config, } if qembedding_group_size is not None: quantize_decoder_embedding_kwargs["qembedding_group_size"] = qembedding_group_size quantize_model_(**quantize_decoder_embedding_kwargs) - # TODO: quantize encoder embeddings. - return MultiModalTextToTextExportableModule( model=eager_model, modality="audio" if audio_encoder_name else "vision", diff --git a/tests/models/test_modeling_gemma3.py b/tests/models/test_modeling_gemma3.py index ff507bb..eeb6834 100644 --- a/tests/models/test_modeling_gemma3.py +++ b/tests/models/test_modeling_gemma3.py @@ -309,9 +309,9 @@ def test_gemma3_image_vision_with_custom_sdpa_kv_cache_8da4w_8we(self): use_custom_kv_cache=True, qlinear="8da4w", qlinear_group_size=32, - # Can't quantize the encoder a the moment, hidden dim of 4304 doesn't fit ExecuTorch's - # XNNPack 32-group size quantized kernels. See https://github.com/pytorch/executorch/issues/14221. - qembedding_config="8w", + qlinear_encoder="8da4w,8da8w", + qlinear_encoder_group_size=32, + qembedding="8w", ) # Generate