16 changes: 10 additions & 6 deletions optimum/commands/export/executorch.py
@@ -76,12 +76,14 @@ def parse_args_executorch(parser):
required_group.add_argument(
"--qlinear",
type=str,
choices=["8da4w", "4w", "8w"],
choices=["8da4w", "4w", "8w", "8da8w", "8da4w,8da8w"],
required=False,
help=(
"Quantization config for decoder linear layers.\n\n"
"Options:\n"
" 8da4w - 8-bit dynamic activation, 4-bit weight\n"
" 8da8w - 8-bit dynamic activation, 8-bit weight\n"
" 8da4w,8da8w - 8-bit dynamic activation, 4-bit weight and 8-bit weight\n"
" 4w - 4-bit weight only\n"
" 8w - 8-bit weight only"
),
@@ -92,12 +94,14 @@ def parse_args_executorch(parser):
required_group.add_argument(
"--qlinear_encoder",
type=str,
choices=["8da4w", "4w", "8w"],
choices=["8da4w", "4w", "8w", "8da8w", "8da4w,8da8w"],
required=False,
help=(
"Quantization config for linear layers.\n\n"
"Options:\n"
" 8da4w - 8-bit dynamic activation, 4-bit weight\n"
" 8da8w - 8-bit dynamic activation, 8-bit weight\n"
" 8da4w,8da8w - 8-bit dynamic activation, 4-bit weight; fallback on 8-bit dynamic activation, 8-bit weight per-channel where group size doesn't divide block size cleanly \n"
" 4w - 4-bit weight only\n"
" 8w - 8-bit weight only"
),
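For orientation, the new choice strings travel to the exporter as plain kwargs; the values below mirror the updated Gemma3 test at the bottom of this PR, and the downstream export entry point that consumes them is outside this diff.

# Illustrative kwargs assembled by run() below from the new CLI choices.
# The export entry point that consumes this dict is not part of this diff.
export_kwargs = {
    "qlinear": "8da4w",                # decoder linears: 8-bit dynamic activation, 4-bit weight
    "qlinear_group_size": 32,
    "qlinear_encoder": "8da4w,8da8w",  # encoder linears: 4-bit groups, per-channel 8-bit fallback
    "qlinear_encoder_group_size": 32,
    "qembedding": "8w",                # embeddings: 8-bit weight only
}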
@@ -148,15 +152,15 @@ def run(self):
if self.args.qlinear:
kwargs["qlinear"] = self.args.qlinear
if self.args.qlinear_group_size:
kwargs["qlinear_group_size"] = self.args.qlinear
kwargs["qlinear_group_size"] = self.args.qlinear_group_size
if self.args.qlinear_encoder:
kwargs["qlinear_encoder"] = self.args.qlinear
kwargs["qlinear_encoder"] = self.args.qlinear_encoder
if self.args.qlinear_encoder_group_size:
kwargs["qlinear_encoder_group_size"] = self.args.qlinear
kwargs["qlinear_encoder_group_size"] = self.args.qlinear_encoder_group_size
if self.args.qembedding:
kwargs["qembedding"] = self.args.qembedding
if self.args.qembedding_group_size:
kwargs["qembedding_group_size"] = self.args.qembedding
kwargs["qembedding_group_size"] = self.args.qembedding_group_size
if self.args.max_seq_len:
kwargs["max_seq_len"] = self.args.max_seq_len

91 changes: 74 additions & 17 deletions optimum/exporters/executorch/quantization.py
@@ -40,7 +40,7 @@ def quantize_model_(
if qlinear_config == "8w":
assert (
qembedding_group_size == 0
), "8-bit embedding quantization only supports per-channel at the moment, please use qembedding_group_size = 0."
), "8-bit embedding quantization only supports per-token at the moment, please use qembedding_group_size = 0."
if qembedding_group_size == 0:
embedding_weight_granularity = PerAxis(0)
else:
@@ -67,30 +67,87 @@
)

if qlinear_config:

def build_linear_config(config_key: str, granularity):
if config_key == "8da4w":
return Int8DynamicActivationIntxWeightConfig(
weight_dtype=torch.int4,
weight_granularity=granularity,
)
if config_key == "4w":
return IntxWeightOnlyConfig(
weight_dtype=torch.int4,
granularity=granularity,
)
if config_key == "8w":
return IntxWeightOnlyConfig(
weight_dtype=torch.int8,
granularity=granularity,
)
if config_key == "8da8w":
return Int8DynamicActivationIntxWeightConfig(
weight_dtype=torch.int8,
weight_granularity=PerAxis(0),
)
raise ValueError(f"Unsupported linear quantization config '{config_key}'.")

qlinear_configs = [cfg.strip() for cfg in qlinear_config.split(",")]
if any(cfg == "" for cfg in qlinear_configs):
raise ValueError("Linear quantization config entries must be non-empty.")
if len(qlinear_configs) > 2:
raise ValueError("Expected at most one fallback linear quantization config, got more than one comma.")

primary_linear_config_key = qlinear_configs[0]
fallback_linear_config_key = qlinear_configs[1] if len(qlinear_configs) == 2 else None

if qlinear_group_size == 0:
linear_weight_granularity = PerAxis(0)
if fallback_linear_config_key is not None:
logging.warning(
"qlinear_group_size is 0, fallback linear config will not be used as all layers will be quantized with per-axis granularity."
)
fallback_linear_config_key = None
else:
assert qlinear_group_size % 2 == 0, "Linear quantization group size must be a multiple of 2."
assert (
qlinear_group_size % 2 == 0
), f"Linear quantization group size must be a multiple of 2, got {qlinear_group_size}."
linear_weight_granularity = PerGroup(qlinear_group_size)

logging.info("Quantizing linear layers.")
linear_config = {
"8da4w": Int8DynamicActivationIntxWeightConfig(
weight_dtype=torch.int4,
weight_granularity=linear_weight_granularity,
),
"4w": IntxWeightOnlyConfig(
weight_dtype=torch.int4,
granularity=linear_weight_granularity,
),
"8w": IntxWeightOnlyConfig(
weight_dtype=torch.int8,
granularity=linear_weight_granularity,
),
}[qlinear_config]
primary_linear_config = build_linear_config(primary_linear_config_key, linear_weight_granularity)

# First, quantize layers that are compatible with group quantization
def per_group_filter(module, fqn):
if isinstance(module, torch.nn.Linear):
# Check if hidden dimension is divisible by group size
# For Linear layers, weight shape is [out_features, in_features]
# Group quantization typically applies to the in_features dimension (dim=1)
return qlinear_group_size == 0 or (module.weight.shape[1] % qlinear_group_size == 0)
return False

quantize_(
eager_model,
linear_config,
primary_linear_config,
filter_fn=per_group_filter,
)

# Then, quantize incompatible layers using the fallback per-axis config
if fallback_linear_config_key is not None:
fallback_linear_config = build_linear_config(fallback_linear_config_key, PerAxis(0))

def per_token_filter(module, fqn):
if isinstance(module, torch.nn.Linear):
return module.weight.shape[1] % qlinear_group_size != 0
return False

logging.info(
f"Applying fallback linear config '{fallback_linear_config_key}' (per-axis)"
f" to layers incompatible with group size {qlinear_group_size}."
)
quantize_(
eager_model,
fallback_linear_config,
filter_fn=per_token_filter,
)

unwrap_tensor_subclass(eager_model)
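
To make the two-pass behavior concrete, here is a minimal usage sketch of quantize_model_ with the new combined config. It assumes a torchao installation plus this PR's quantization.py; the toy layer sizes are purely illustrative.

# Minimal sketch (assumes torchao and this PR's quantize_model_; sizes are illustrative).
import torch
from optimum.exporters.executorch.quantization import quantize_model_


class ToyDecoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # in_features divisible by 32 -> quantized in the first pass (8da4w, PerGroup(32))
        self.proj_a = torch.nn.Linear(64, 64)
        # in_features not divisible by 32 -> picked up by the fallback pass (8da8w, PerAxis(0))
        self.proj_b = torch.nn.Linear(72, 64)


model = ToyDecoder()
quantize_model_(
    eager_model=model,
    qlinear_config="8da4w,8da8w",
    qlinear_group_size=32,
)

With plain "8da4w" and the same group size, proj_b would simply be skipped by per_group_filter and left in floating point; the comma-separated form is what lets it be quantized at all.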
18 changes: 13 additions & 5 deletions optimum/exporters/executorch/tasks/multimodal_text_to_text.py
@@ -14,6 +14,7 @@


import json
import logging
import os.path

import torchao
@@ -201,15 +202,24 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
qembedding_group_size = kwargs.get("qembedding_group_size", None)

# Quantize decoder linear weights.
if qlinear_config:
logging.info("Quantizing decoder linears...")
quantize_decoder_kwargs = {
"eager_model": getattr(eager_model, decoder_name),
"qlinear_config": qlinear_config,
}
quantize_lm_head_kwargs = {
"eager_model": eager_model.lm_head,
"qlinear_config": qlinear_config,
}
Comment on lines +211 to +214
Reviewer: Can you guard this by whether eager_model has lm_head?
Author: Sure. Curious though, is there a model without lm_head?
Reviewer: Yeah, Voxtral doesn't have lm_head.

if qlinear_group_size is not None:
quantize_decoder_kwargs["qlinear_group_size"] = qlinear_group_size
quantize_model_(**quantize_decoder_kwargs)
quantize_model_(**quantize_lm_head_kwargs)

# Quantize encoder linear weights.
if qlinear_encoder_config:
logging.info("Quantizing encoder linears...")
quantize_encoder_kwargs = {
"eager_model": getattr(eager_model, encoder_name),
"qlinear_config": qlinear_encoder_config,
@@ -218,19 +228,17 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
quantize_encoder_kwargs["qlinear_group_size"] = qlinear_encoder_group_size
quantize_model_(**quantize_encoder_kwargs)

# TODO: quantize other parts of the model, e.g. MultimodalProjector?

# Quantize decoder embeddings.
if qembedding_config:
logging.info("Quantizing decoder embeddings...")
quantize_decoder_embedding_kwargs = {
"eager_model": getattr(eager_model, decoder_name),
"eager_model": eager_model,
"qembedding_config": qembedding_config,
}
if qembedding_group_size is not None:
quantize_decoder_embedding_kwargs["qembedding_group_size"] = qembedding_group_size
quantize_model_(**quantize_decoder_embedding_kwargs)

# TODO: quantize encoder embeddings.

return MultiModalTextToTextExportableModule(
model=eager_model,
modality="audio" if audio_encoder_name else "vision",
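Following up on the review thread above, a sketch of the requested guard, assuming a simple attribute check is acceptable; the form the PR finally adopts is not visible in this diff.

# Sketch of the lm_head guard discussed in review (e.g. Voxtral has no lm_head).
# The exact implementation adopted by the PR is not shown in this diff.
if qlinear_config and getattr(eager_model, "lm_head", None) is not None:
    quantize_model_(
        eager_model=eager_model.lm_head,
        qlinear_config=qlinear_config,
    )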
6 changes: 3 additions & 3 deletions tests/models/test_modeling_gemma3.py
@@ -309,9 +309,9 @@ def test_gemma3_image_vision_with_custom_sdpa_kv_cache_8da4w_8we(self):
use_custom_kv_cache=True,
qlinear="8da4w",
qlinear_group_size=32,
# Can't quantize the encoder at the moment, hidden dim of 4304 doesn't fit ExecuTorch's
# XNNPack 32-group size quantized kernels. See https://github.com/pytorch/executorch/issues/14221.
qembedding_config="8w",
qlinear_encoder="8da4w,8da8w",
qlinear_encoder_group_size=32,
qembedding="8w",
)

# Generate
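The encoder can now be quantized because of the new fallback: Gemma3's 4304-wide encoder projections are not divisible by the 32-element group size (the reason the removed comment gave for skipping encoder quantization), so with "8da4w,8da8w" those layers take the per-channel 8da8w path instead of being skipped. A quick check of the divisibility rule the filters apply (4096 is just an illustrative divisible size):

# Divisibility rule behind the 8da4w,8da8w fallback.
# 4304 is the encoder hidden dim cited in the removed comment; 4096 is illustrative.
group_size = 32
for in_features in (4096, 4304):
    if in_features % group_size == 0:
        print(f"{in_features}: 8da4w with PerGroup({group_size})")
    else:
        print(f"{in_features}: 8da8w fallback with PerAxis(0)")  # 4304 % 32 == 16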