6 changes: 3 additions & 3 deletions examples/models/llama2/README.md
@@ -162,13 +162,13 @@ python -m examples.models.llama2.export_llama \
     --params "${LLAMA_PARAMS:?}" \
     --use_sdpa_with_kv_cache \
     -X \
-    --spin_qmode 8da4w_output_8da8w \
-    --spin_group_size 32 \
+    --preq_mode 8da4w_output_8da8w \
+    --preq_group_size 32 \
     --max_seq_length 2048 \
     --output_name "llama3_2.pte" \
     -kv \
     -d fp32 \
-    --spin_embedding_quantize 8,0 \
+    --preq_embedding_quantize 8,0 \
     --use_spin_quant native \
     --metadata '{"append_eos_to_prompt": 0, "get_bos_id":128000, "get_eos_ids":[128009, 128001], "get_n_bos": 0, "get_n_eos": 0}'
 ```
1 change: 1 addition & 0 deletions examples/models/llama2/TARGETS
@@ -80,6 +80,7 @@ runtime.python_library(
         "export_llama_lib.py",
         "model.py",
         "source_transformation/apply_spin_quant_r1_r2.py",
+        "source_transformation/pre_quantization.py",
         "source_transformation/prune_output.py",
         "source_transformation/quantize.py",
         "source_transformation/rms_norm.py",
12 changes: 6 additions & 6 deletions examples/models/llama2/export_llama_lib.py
@@ -381,25 +381,25 @@ def build_args_parser() -> argparse.ArgumentParser:
     )

     parser.add_argument(
-        "--spin_qmode",
+        "--preq_mode",
         type=str,
         default=None,
         choices=["8da4w", "8da4w_output_8da8w"],
-        help="Quantization mode for SpinQuant. Only support 8da4w and 8da4w_output_8da8w right now.",
+        help="Quantization mode of the pre-quantized checkpoint. Only 8da4w and 8da4w_output_8da8w are supported right now.",
     )

     parser.add_argument(
-        "--spin_group_size",
+        "--preq_group_size",
         type=int,
         default=32,
-        help="group_size for SpinQuant weight quantization",
+        help="Group size used for weight quantization of the pre-quantized checkpoint.",
     )

     parser.add_argument(
-        "--spin_embedding_quantize",
+        "--preq_embedding_quantize",
         default="8,0",
         type=str,
-        help="type of embedding quantization for SpinQuant, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
+        help="Type of embedding quantization for the pre-quantized checkpoint, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
     )

     parser.add_argument(
40 changes: 20 additions & 20 deletions examples/models/llama2/model.py
@@ -191,20 +191,20 @@ def __init__(self, **kwargs):
             )
         elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
             print("Using SPIN quantization.")
-            assert hasattr(self.args, "spin_qmode"), "spin_qmode must be specified"
-            assert self.args.spin_qmode in [
+            assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
+            assert self.args.preq_mode in [
                 "8da4w",
                 "8da4w_output_8da8w",
-            ], f"Quantization mode {self.args.spin_qmode} is not compatible with SpinQuant."
+            ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
             assert hasattr(
-                self.args, "spin_group_size"
-            ), "spin_group_size must be specified"
+                self.args, "preq_group_size"
+            ), "preq_group_size must be specified"
             assert hasattr(
                 self.args, "dtype_override"
             ), "dtype_override must be specified"
-            from .source_transformation.spin_quant import (
-                sanitize_checkpoint_from_spinquant,
-                transform_linear_for_spinquant,
+            from .source_transformation.pre_quantization import (
+                sanitize_checkpoint_from_pre_quantization,
+                transform_linear_for_pre_quantization,
             )

             mapping = {
@@ -214,31 +214,31 @@ def __init__(self, **kwargs):
             }

             # Transform the output layer first if needed.
-            if self.args.spin_qmode == "8da4w_output_8da8w":
-                from .source_transformation.spin_quant import (
-                    transform_output_linear_for_spinquant,
+            if self.args.preq_mode == "8da4w_output_8da8w":
+                from .source_transformation.pre_quantization import (
+                    transform_output_linear_for_pre_quantization,
                 )

-                self.model_ = transform_output_linear_for_spinquant(
+                self.model_ = transform_output_linear_for_pre_quantization(
                     module=self.model_,
                     checkpoint=checkpoint,
                     dtype=mapping[self.args.dtype_override],
                 )

-            self.model_ = transform_linear_for_spinquant(
+            self.model_ = transform_linear_for_pre_quantization(
                 self.model_,
                 checkpoint,
-                self.args.spin_group_size,
+                self.args.preq_group_size,
                 mapping[self.args.dtype_override],
             )

             embedding_bit_width, embedding_group_size = None, None
-            if hasattr(self.args, "spin_embedding_quantize"):
+            if hasattr(self.args, "preq_embedding_quantize"):
                 embedding_bit_width, embedding_group_size = (
-                    self.args.spin_embedding_quantize.split(",")
+                    self.args.preq_embedding_quantize.split(",")
                 )
-                from .source_transformation.spin_quant import (
-                    transform_embedding_for_spinquant,
+                from .source_transformation.pre_quantization import (
+                    transform_embedding_for_pre_quantization,
                 )

                 if (
@@ -250,15 +250,15 @@ def __init__(self, **kwargs):
                 else:
                     embedding_group_size = int(embedding_group_size)

-                self.model_ = transform_embedding_for_spinquant(
+                self.model_ = transform_embedding_for_pre_quantization(
                     self.model_,
                     checkpoint,
                     mapping[self.args.dtype_override],
                     int(embedding_bit_width),
                     embedding_group_size,
                 )

-            sanitize_checkpoint_from_spinquant(checkpoint)
+            sanitize_checkpoint_from_pre_quantization(checkpoint)

         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
         # Because we are using device="meta", tensors do not have memory associated with them
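The `assign=True` comment above is the key detail when the model is materialized on the `meta` device: meta tensors have no storage, so an in-place copy would fail. A minimal sketch of the pattern in plain PyTorch (the module and shapes are illustrative assumptions, not part of this PR):

```python
import torch
from torch import nn

# Parameters created under the meta device have shapes/dtypes but no storage,
# so load_state_dict() cannot copy into them in place.
with torch.device("meta"):
    model = nn.Linear(16, 4, bias=False)

checkpoint = {"weight": torch.randn(4, 16)}  # stand-in for a real state_dict

# assign=True replaces the meta parameters with the checkpoint tensors
# instead of copying into (non-existent) storage.
model.load_state_dict(checkpoint, assign=True)
print(model.weight.device)  # cpu
```

The transforms above rewrite the module structure first, so the pre-quantized checkpoint keys line up with the new modules before this assignment-based load.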
191 changes: 191 additions & 0 deletions examples/models/llama2/source_transformation/pre_quantization.py
@@ -0,0 +1,191 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Helper functions for transforming the model so that it can load pre-quantized checkpoints.

from typing import Any, Optional

import torch
from torch import nn

from torchao.quantization.GPTQ import _check_linear_int4_k, Int8DynActInt4WeightLinear
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

from .quantize import Int8DynActInt8WeightLinear, QuantizedGroupEmbedding


def _replace_linear_with_linear_8da4w_for_pre_quantization(
module: torch.nn.Module,
checkpoint: Any,
group_size: int,
precision: torch.dtype,
scales_precision: torch.dtype,
):
def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
# Only replace linear layers where the checkpoint contains explicit scales
scales_key = f"{cur_fqn}.scales"
if isinstance(child, nn.Linear) and scales_key in checkpoint:
assert _check_linear_int4_k(child.in_features, group_size)
assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
assert checkpoint[scales_key].dtype == scales_precision
return True
return False

def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
new_linear = Int8DynActInt4WeightLinear(
child.in_features,
child.out_features,
bias=False,
device=child.weight.device,
groupsize=group_size,
precision=precision,
scales_precision=scales_precision,
)
return new_linear

_replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def transform_linear_for_pre_quantization(
module: torch.nn.Module,
checkpoint: Any,
group_size: int,
dtype: torch.dtype,
) -> torch.nn.Module:
"""
Transform the model to be able to load pre-quantized checkpoints that
are quantized with the given group size and quantization mode for
linear layers.
"""

if group_size not in [32, 64, 128, 256]:
raise ValueError(
f"Group size {group_size} is not supported for pre-quantized checkpoint."
)
_replace_linear_with_linear_8da4w_for_pre_quantization(
module,
checkpoint,
group_size,
dtype,
dtype,
)
return module


def _replace_output_linear_with_linear_int8_for_pre_quantization(
module: torch.nn.Module,
checkpoint: Any,
dtype: torch.dtype,
):
def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
scales_key = f"{cur_fqn}.scales"
if (
isinstance(child, nn.Linear)
and scales_key in checkpoint
and "output" in cur_fqn
):
assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
assert checkpoint[scales_key].dtype == dtype
return True
return False

def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
new_linear = Int8DynActInt8WeightLinear(
device=child.weight.device,
in_features=child.in_features,
out_features=child.out_features,
precision=dtype,
bias=False,
)
return new_linear

_replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def transform_output_linear_for_pre_quantization(
module: torch.nn.Module,
checkpoint: Any,
dtype: torch.dtype,
) -> torch.nn.Module:
"""
Transform the model to be able to load pre-quantized checkpoints that
    have the output layer quantized per-channel.
"""
_replace_output_linear_with_linear_int8_for_pre_quantization(
module,
checkpoint,
dtype,
)
return module


def _replace_embedding_with_quantized_group_embedding_for_pre_quantization(
module: torch.nn.Module,
checkpoint: Any,
dtype: torch.dtype,
bit_width: int,
group_size: Optional[int] = None,
):
def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
# Only replace embedding layers where the checkpoint contains explicit scales
scales_key = f"{cur_fqn}.scales"
if isinstance(child, nn.Embedding) and scales_key in checkpoint:
assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
assert checkpoint[scales_key].dtype == torch.float32
return True
return False

def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
new_embedding = QuantizedGroupEmbedding(
device=child.weight.device,
vocab_size=child.weight.shape[0],
embedding_dim=child.weight.shape[1],
group_size=group_size,
dtype=dtype,
packed=False, # TODO(lunwenh): support packed embedding for pre-quantized
)
return new_embedding

_replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def transform_embedding_for_pre_quantization(
module: torch.nn.Module,
checkpoint: Any,
dtype: torch.dtype,
bit_width: int,
group_size: Optional[int] = None,
) -> torch.nn.Module:
"""
Transform the model to be able to load pre-quantized checkpoints that
are quantized with the given bit_width and group size for embedding.
"""
if group_size is not None and group_size not in [0, 32, 64, 128, 256]:
raise ValueError(
f"Group size {group_size} is not supported for pre-quantized checkpoint."
)
_replace_embedding_with_quantized_group_embedding_for_pre_quantization(
module,
checkpoint,
dtype,
bit_width,
group_size,
)
return module


def sanitize_checkpoint_from_pre_quantization(
checkpoint: Any,
):
"""
Sanitize the pre-quantized checkpoint.
- Converts all tensors to contiguous format
    - Squeezes all tensors
"""
for k, v in checkpoint.items():
checkpoint[k] = torch.squeeze(v.contiguous())
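
For reference, a hypothetical end-to-end sketch of how these helpers fit together, mirroring the `model.py` changes above. The import path, dtype, group size, and the `strict=False` flag are assumptions for illustration, not part of this PR:

```python
import torch

# Import path assumed from this file's location in the tree.
from executorch.examples.models.llama2.source_transformation.pre_quantization import (
    sanitize_checkpoint_from_pre_quantization,
    transform_embedding_for_pre_quantization,
    transform_linear_for_pre_quantization,
)


def load_pre_quantized(model: torch.nn.Module, checkpoint: dict) -> torch.nn.Module:
    # Swap nn.Linear layers that have scales in the checkpoint for 8da4w linears (group size 32).
    model = transform_linear_for_pre_quantization(model, checkpoint, 32, torch.float32)
    # Swap nn.Embedding for an 8-bit QuantizedGroupEmbedding (default group size).
    model = transform_embedding_for_pre_quantization(model, checkpoint, torch.float32, 8)
    # Make the checkpoint tensors contiguous and squeezed so they match the new modules.
    sanitize_checkpoint_from_pre_quantization(checkpoint)
    # assign=True as in model.py above; strict=False is an assumption for this sketch.
    model.load_state_dict(checkpoint, assign=True, strict=False)
    return model
```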