16 changes: 10 additions & 6 deletions optimum/commands/export/executorch.py
@@ -76,12 +76,14 @@ def parse_args_executorch(parser):
required_group.add_argument(
"--qlinear",
type=str,
choices=["8da4w", "4w", "8w"],
choices=["8da4w", "4w", "8w", "8da8w", "8da4w,8da8w"],
required=False,
help=(
"Quantization config for decoder linear layers.\n\n"
"Options:\n"
" 8da4w - 8-bit dynamic activation, 4-bit weight\n"
" 8da8w - 8-bit dynamic activation, 8-bit weight\n"
" 8da4w,8da8w - 8-bit dynamic activation, 4-bit weight and 8-bit weight\n"
" 4w - 4-bit weight only\n"
" 8w - 8-bit weight only"
),
@@ -92,12 +94,14 @@ def parse_args_executorch(parser):
required_group.add_argument(
"--qlinear_encoder",
type=str,
choices=["8da4w", "4w", "8w"],
choices=["8da4w", "4w", "8w", "8da8w", "8da4w,8da8w"],
required=False,
help=(
"Quantization config for linear layers.\n\n"
"Options:\n"
" 8da4w - 8-bit dynamic activation, 4-bit weight\n"
" 8da8w - 8-bit dynamic activation, 8-bit weight\n"
" 8da4w,8da8w - 8-bit dynamic activation, 4-bit weight and 8-bit weight\n"
" 4w - 4-bit weight only\n"
" 8w - 8-bit weight only"
),
@@ -148,15 +152,15 @@ def run(self):
if self.args.qlinear:
kwargs["qlinear"] = self.args.qlinear
if self.args.qlinear_group_size:
kwargs["qlinear_group_size"] = self.args.qlinear
kwargs["qlinear_group_size"] = self.args.qlinear_group_size
if self.args.qlinear_encoder:
kwargs["qlinear_encoder"] = self.args.qlinear
kwargs["qlinear_encoder"] = self.args.qlinear_encoder
if self.args.qlinear_encoder_group_size:
kwargs["qlinear_encoder_group_size"] = self.args.qlinear
kwargs["qlinear_encoder_group_size"] = self.args.qlinear_encoder_group_size
if self.args.qembedding:
kwargs["qembedding"] = self.args.qembedding
if self.args.qembedding_group_size:
kwargs["qembedding_group_size"] = self.args.qembedding
kwargs["qembedding_group_size"] = self.args.qembedding_group_size
if self.args.max_seq_len:
kwargs["max_seq_len"] = self.args.max_seq_len

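As a usage note (not part of this diff), a hedged sketch of how the new fallback syntax might be passed on the command line; the model name, recipe, and output directory are illustrative placeholders, and only the --qlinear flags come from this PR:

optimum-cli export executorch \
  --model HuggingFaceTB/SmolLM2-135M \
  --task text-generation \
  --recipe xnnpack \
  --qlinear 8da4w,8da8w \
  --qlinear_group_size 32 \
  --output_dir ./smollm2_executorch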
94 changes: 78 additions & 16 deletions optimum/exporters/executorch/quantization.py
@@ -67,30 +67,92 @@ def quantize_model_(
)

if qlinear_config:
def build_linear_config(config_key: str, granularity):
if config_key == "8da4w":
return Int8DynamicActivationIntxWeightConfig(
weight_dtype=torch.int4,
weight_granularity=granularity,
)
if config_key == "4w":
return IntxWeightOnlyConfig(
weight_dtype=torch.int4,
granularity=granularity,
)
if config_key == "8w":
return IntxWeightOnlyConfig(
weight_dtype=torch.int8,
granularity=granularity,
)
if config_key == "8da8w":
return Int8DynamicActivationIntxWeightConfig(
weight_dtype=torch.int8,
weight_granularity=PerAxis(0),
)
raise ValueError(f"Unsupported linear quantization config '{config_key}'.")

qlinear_configs = [cfg.strip() for cfg in qlinear_config.split(",")]
if any(cfg == "" for cfg in qlinear_configs):
raise ValueError("Linear quantization config entries must be non-empty.")
if len(qlinear_configs) > 2:
raise ValueError(
"Expected at most one fallback linear quantization config, got more than one comma."
)

primary_linear_config_key = qlinear_configs[0]
fallback_linear_config_key = qlinear_configs[1] if len(qlinear_configs) == 2 else None

if qlinear_group_size == 0:
linear_weight_granularity = PerAxis(0)
if fallback_linear_config_key is not None:
logging.warning(
"qlinear_group_size is 0, fallback linear config will not be used as all layers will be quantized with per-axis granularity."
)
fallback_linear_config_key = None
else:
assert qlinear_group_size % 2 == 0, "Linear quantization group size must be a multiple of 2."
assert qlinear_group_size % 2 == 0, f"Linear quantization group size must be a multiple of 2, got {qlinear_group_size}."
Contributor: Why is group size a multiple of 2? Shouldn't it be a multiple of 32?

linear_weight_granularity = PerGroup(qlinear_group_size)

logging.info("Quantizing linear layers.")
linear_config = {
"8da4w": Int8DynamicActivationIntxWeightConfig(
weight_dtype=torch.int4,
weight_granularity=linear_weight_granularity,
),
"4w": IntxWeightOnlyConfig(
weight_dtype=torch.int4,
granularity=linear_weight_granularity,
),
"8w": IntxWeightOnlyConfig(
weight_dtype=torch.int8,
granularity=linear_weight_granularity,
),
}[qlinear_config]
primary_linear_config = build_linear_config(
primary_linear_config_key, linear_weight_granularity
)

# First, quantize layers that are compatible with group quantization
def quant_filter(module, fqn):
if isinstance(module, torch.nn.Linear):
# Check if hidden dimension is divisible by group size
# For Linear layers, weight shape is [out_features, in_features]
# Group quantization typically applies to the in_features dimension (dim=1)
return qlinear_group_size == 0 or (module.weight.shape[1] % qlinear_group_size == 0)
return False

quantize_(
eager_model,
linear_config,
primary_linear_config,
filter_fn=quant_filter,
)

# Then, quantize incompatible layers using the fallback per-axis config
if fallback_linear_config_key is not None:
fallback_linear_config = build_linear_config(
fallback_linear_config_key, PerAxis(0)
)

def per_channel_filter(module, fqn):
if isinstance(module, torch.nn.Linear):
# Only quantize layers that are NOT compatible with group quantization
# and haven't been quantized yet
return not quant_filter(module, fqn)
return False

logging.info(
f"Applying fallback linear config '{fallback_linear_config_key}' (per-axis)"
f" to layers incompatible with group size {qlinear_group_size}."
)
quantize_(
eager_model,
fallback_linear_config,
filter_fn=per_channel_filter,
)

unwrap_tensor_subclass(eager_model)
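To make the two-pass selection above concrete, here is a small illustrative sketch (not from the PR; the layer shapes and the group size of 128 are made up) of which Linear layers the primary grouped config would cover and which would fall through to the per-axis fallback, mirroring quant_filter:

import torch

group_size = 128  # hypothetical qlinear_group_size
layers = {
    "mlp.up_proj": torch.nn.Linear(2048, 4096),    # in_features=2048, divisible by 128
    "connector.proj": torch.nn.Linear(576, 2048),  # in_features=576, 576 % 128 != 0
}

for name, module in layers.items():
    # Linear weights are [out_features, in_features]; grouping runs along in_features (dim=1).
    in_features = module.weight.shape[1]
    if group_size == 0 or in_features % group_size == 0:
        print(f"{name}: primary config, PerGroup({group_size})")
    else:
        print(f"{name}: fallback config, PerAxis(0)")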
14 changes: 9 additions & 5 deletions optimum/exporters/executorch/tasks/multimodal_text_to_text.py
@@ -218,19 +218,23 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
quantize_encoder_kwargs["qlinear_group_size"] = qlinear_encoder_group_size
quantize_model_(**quantize_encoder_kwargs)

# TODO: quantize other parts of the model, e.g. MultimodalProjector?

# Quantize decoder embeddings.
quantize_decoder_embedding_kwargs = {
"eager_model": getattr(eager_model, decoder_name),
"eager_model": eager_model,
"qembedding_config": qembedding_config,
}
if qembedding_group_size is not None:
quantize_decoder_embedding_kwargs["qembedding_group_size"] = qembedding_group_size
quantize_model_(**quantize_decoder_embedding_kwargs)

# TODO: quantize encoder embeddings.

# Quantize lm_head
if hasattr(eager_model, "lm_head") and qlinear_config is not None:
quantize_model_(
eager_model=eager_model.lm_head,
qlinear_config=qlinear_config,
qlinear_group_size=qlinear_group_size if qlinear_group_size is not None else 0,
)
print(eager_model)
return MultiModalTextToTextExportableModule(
model=eager_model,
modality="audio" if audio_encoder_name else "vision",
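Finally, a hedged end-to-end sketch of the new config string reaching quantize_model_ directly. The module and layer sizes are hypothetical, the keyword arguments mirror the lm_head call above, and the import path is assumed from the file location touched in this PR:

import torch
from optimum.exporters.executorch.quantization import quantize_model_  # assumed module path

class TinyDecoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_a = torch.nn.Linear(2048, 2048)  # 2048 % 128 == 0 -> 8da4w with PerGroup(128)
        self.proj_b = torch.nn.Linear(576, 2048)   # 576 % 128 != 0  -> 8da8w fallback, PerAxis(0)

    def forward(self, x):
        return self.proj_b(self.proj_a(x))

model = TinyDecoder()
quantize_model_(
    eager_model=model,
    qlinear_config="8da4w,8da8w",
    qlinear_group_size=128,
)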