From 3dc62c6bb71f9526b59e3ee8ce9f1eb0febe6db1 Mon Sep 17 00:00:00 2001 From: Lunwen He Date: Tue, 1 Oct 2024 16:31:27 -0700 Subject: [PATCH] update spinquant quantization options to be general purposed pre-quantization (#5797) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/5797 We decided to use the same quantization scheme and checkpoint format for QAT + LoRA. This PR updates related quantization cli options to be general purposed for pre-quantized checkpoints. Differential Revision: D63708762 --- examples/models/llama2/README.md | 6 +- examples/models/llama2/TARGETS | 1 + examples/models/llama2/export_llama_lib.py | 12 +- examples/models/llama2/model.py | 40 ++-- .../source_transformation/pre_quantization.py | 191 ++++++++++++++++++ .../source_transformation/spin_quant.py | 173 ---------------- examples/models/llama2/tests/TARGETS | 4 +- ...py => test_pre_quantization_transforms.py} | 33 ++- pytest.ini | 2 +- 9 files changed, 240 insertions(+), 222 deletions(-) create mode 100644 examples/models/llama2/source_transformation/pre_quantization.py rename examples/models/llama2/tests/{test_spinquant_transforms.py => test_pre_quantization_transforms.py} (86%) diff --git a/examples/models/llama2/README.md b/examples/models/llama2/README.md index 138044d5342..bcca1b82ba4 100644 --- a/examples/models/llama2/README.md +++ b/examples/models/llama2/README.md @@ -162,13 +162,13 @@ python -m examples.models.llama2.export_llama \ --params "${LLAMA_PARAMS:?}" \ --use_sdpa_with_kv_cache \ -X \ - --spin_qmode 8da4w_output_8da8w \ - --spin_group_size 32 \ + --preq_mode 8da4w_output_8da8w \ + --preq_group_size 32 \ --max_seq_length 2048 \ --output_name "llama3_2.pte" \ -kv \ -d fp32 \ - --spin_embedding_quantize 8,0 \ + --preq_embedding_quantize 8,0 \ --use_spin_quant native \ --metadata '{"append_eos_to_prompt": 0, "get_bos_id":128000, "get_eos_ids":[128009, 128001], "get_n_bos": 0, "get_n_eos": 0}' ``` diff --git a/examples/models/llama2/TARGETS b/examples/models/llama2/TARGETS index e2fb5a9177a..0918801f9e1 100644 --- a/examples/models/llama2/TARGETS +++ b/examples/models/llama2/TARGETS @@ -80,6 +80,7 @@ runtime.python_library( "export_llama_lib.py", "model.py", "source_transformation/apply_spin_quant_r1_r2.py", + "source_transformation/pre_quantization.py", "source_transformation/prune_output.py", "source_transformation/quantize.py", "source_transformation/rms_norm.py", diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index 24c88d48994..305df7a90dd 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -381,25 +381,25 @@ def build_args_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--spin_qmode", + "--preq_mode", type=str, default=None, choices=["8da4w", "8da4w_output_8da8w"], - help="Quantization mode for SpinQuant. Only support 8da4w and 8da4w_output_8da8w right now.", + help="Quantization mode used for pre-quantized checkpoint. 
Only support 8da4w and 8da4w_output_8da8w right now.", ) parser.add_argument( - "--spin_group_size", + "--preq_group_size", type=int, default=32, - help="group_size for SpinQuant weight quantization", + help="group_size for pre-quantized checkpoint weight quantization", ) parser.add_argument( - "--spin_embedding_quantize", + "--preq_embedding_quantize", default="8,0", type=str, - help="type of embedding quantization for SpinQuant, ',', e.g., '8,1024'.", + help="type of embedding quantization for pre-quantized checkpoint, ',', e.g., '8,1024'.", ) parser.add_argument( diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py index c48fa98d576..a4081d1bd57 100644 --- a/examples/models/llama2/model.py +++ b/examples/models/llama2/model.py @@ -191,20 +191,20 @@ def __init__(self, **kwargs): ) elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant: print("Using SPIN quantization.") - assert hasattr(self.args, "spin_qmode"), "spin_qmode must be specified" - assert self.args.spin_qmode in [ + assert hasattr(self.args, "preq_mode"), "preq_mode must be specified" + assert self.args.preq_mode in [ "8da4w", "8da4w_output_8da8w", - ], f"Quantization mode {self.args.spin_qmode} is not compatible with SpinQuant." + ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant." assert hasattr( - self.args, "spin_group_size" - ), "spin_group_size must be specified" + self.args, "preq_group_size" + ), "preq_group_size must be specified" assert hasattr( self.args, "dtype_override" ), "dtype_override must be specified" - from .source_transformation.spin_quant import ( - sanitize_checkpoint_from_spinquant, - transform_linear_for_spinquant, + from .source_transformation.pre_quantization import ( + sanitize_checkpoint_from_pre_quantization, + transform_linear_for_pre_quantization, ) mapping = { @@ -214,31 +214,31 @@ def __init__(self, **kwargs): } # Transform the output layer first if needed. 
- if self.args.spin_qmode == "8da4w_output_8da8w": - from .source_transformation.spin_quant import ( - transform_output_linear_for_spinquant, + if self.args.preq_mode == "8da4w_output_8da8w": + from .source_transformation.pre_quantization import ( + transform_output_linear_for_pre_quantization, ) - self.model_ = transform_output_linear_for_spinquant( + self.model_ = transform_output_linear_for_pre_quantization( module=self.model_, checkpoint=checkpoint, dtype=mapping[self.args.dtype_override], ) - self.model_ = transform_linear_for_spinquant( + self.model_ = transform_linear_for_pre_quantization( self.model_, checkpoint, - self.args.spin_group_size, + self.args.preq_group_size, mapping[self.args.dtype_override], ) embedding_bit_width, embedding_group_size = None, None - if hasattr(self.args, "spin_embedding_quantize"): + if hasattr(self.args, "preq_embedding_quantize"): embedding_bit_width, embedding_group_size = ( - self.args.spin_embedding_quantize.split(",") + self.args.preq_embedding_quantize.split(",") ) - from .source_transformation.spin_quant import ( - transform_embedding_for_spinquant, + from .source_transformation.pre_quantization import ( + transform_embedding_for_pre_quantization, ) if ( @@ -250,7 +250,7 @@ def __init__(self, **kwargs): else: embedding_group_size = int(embedding_group_size) - self.model_ = transform_embedding_for_spinquant( + self.model_ = transform_embedding_for_pre_quantization( self.model_, checkpoint, mapping[self.args.dtype_override], @@ -258,7 +258,7 @@ def __init__(self, **kwargs): embedding_group_size, ) - sanitize_checkpoint_from_spinquant(checkpoint) + sanitize_checkpoint_from_pre_quantization(checkpoint) # assign=True: load params/buffers by assignment instead of performing an in-place copy. # Because we are using device="meta", tensors do not have memory associated with them diff --git a/examples/models/llama2/source_transformation/pre_quantization.py b/examples/models/llama2/source_transformation/pre_quantization.py new file mode 100644 index 00000000000..38937c5ab4e --- /dev/null +++ b/examples/models/llama2/source_transformation/pre_quantization.py @@ -0,0 +1,191 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +# Helper functions for tranforming the model to be able to load pre-quantized checkpoints. 
+ +from typing import Any, Optional + +import torch +from torch import nn + +from torchao.quantization.GPTQ import _check_linear_int4_k, Int8DynActInt4WeightLinear +from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter + +from .quantize import Int8DynActInt8WeightLinear, QuantizedGroupEmbedding + + +def _replace_linear_with_linear_8da4w_for_pre_quantization( + module: torch.nn.Module, + checkpoint: Any, + group_size: int, + precision: torch.dtype, + scales_precision: torch.dtype, +): + def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: + # Only replace linear layers where the checkpoint contains explicit scales + scales_key = f"{cur_fqn}.scales" + if isinstance(child, nn.Linear) and scales_key in checkpoint: + assert _check_linear_int4_k(child.in_features, group_size) + assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8 + assert checkpoint[scales_key].dtype == scales_precision + return True + return False + + def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: + new_linear = Int8DynActInt4WeightLinear( + child.in_features, + child.out_features, + bias=False, + device=child.weight.device, + groupsize=group_size, + precision=precision, + scales_precision=scales_precision, + ) + return new_linear + + _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn) + + +def transform_linear_for_pre_quantization( + module: torch.nn.Module, + checkpoint: Any, + group_size: int, + dtype: torch.dtype, +) -> torch.nn.Module: + """ + Transform the model to be able to load pre-quantized checkpoints that + are quantized with the given group size and quantization mode for + linear layers. + """ + + if group_size not in [32, 64, 128, 256]: + raise ValueError( + f"Group size {group_size} is not supported for pre-quantized checkpoint." + ) + _replace_linear_with_linear_8da4w_for_pre_quantization( + module, + checkpoint, + group_size, + dtype, + dtype, + ) + return module + + +def _replace_output_linear_with_linear_int8_for_pre_quantization( + module: torch.nn.Module, + checkpoint: Any, + dtype: torch.dtype, +): + def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: + scales_key = f"{cur_fqn}.scales" + if ( + isinstance(child, nn.Linear) + and scales_key in checkpoint + and "output" in cur_fqn + ): + assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8 + assert checkpoint[scales_key].dtype == dtype + return True + return False + + def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: + new_linear = Int8DynActInt8WeightLinear( + device=child.weight.device, + in_features=child.in_features, + out_features=child.out_features, + precision=dtype, + bias=False, + ) + return new_linear + + _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn) + + +def transform_output_linear_for_pre_quantization( + module: torch.nn.Module, + checkpoint: Any, + dtype: torch.dtype, +) -> torch.nn.Module: + """ + Transform the model to be able to load pre-quantized checkpoints that + has the output layer quantized per-channel. 
+ """ + _replace_output_linear_with_linear_int8_for_pre_quantization( + module, + checkpoint, + dtype, + ) + return module + + +def _replace_embedding_with_quantized_group_embedding_for_pre_quantization( + module: torch.nn.Module, + checkpoint: Any, + dtype: torch.dtype, + bit_width: int, + group_size: Optional[int] = None, +): + def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: + # Only replace embedding layers where the checkpoint contains explicit scales + scales_key = f"{cur_fqn}.scales" + if isinstance(child, nn.Embedding) and scales_key in checkpoint: + assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8 + assert checkpoint[scales_key].dtype == torch.float32 + return True + return False + + def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: + new_embedding = QuantizedGroupEmbedding( + device=child.weight.device, + vocab_size=child.weight.shape[0], + embedding_dim=child.weight.shape[1], + group_size=group_size, + dtype=dtype, + packed=False, # TODO(lunwenh): support packed embedding for pre-quantized + ) + return new_embedding + + _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn) + + +def transform_embedding_for_pre_quantization( + module: torch.nn.Module, + checkpoint: Any, + dtype: torch.dtype, + bit_width: int, + group_size: Optional[int] = None, +) -> torch.nn.Module: + """ + Transform the model to be able to load pre-quantized checkpoints that + are quantized with the given bit_width and group size for embedding. + """ + if group_size is not None and group_size not in [0, 32, 64, 128, 256]: + raise ValueError( + f"Group size {group_size} is not supported for pre-quantized checkpoint." + ) + _replace_embedding_with_quantized_group_embedding_for_pre_quantization( + module, + checkpoint, + dtype, + bit_width, + group_size, + ) + return module + + +def sanitize_checkpoint_from_pre_quantization( + checkpoint: Any, +): + """ + Sanitize the pre-quantized checkpoint. + - Converts all tensors to contiguous format + - Squeeze all tensors + """ + for k, v in checkpoint.items(): + checkpoint[k] = torch.squeeze(v.contiguous()) diff --git a/examples/models/llama2/source_transformation/spin_quant.py b/examples/models/llama2/source_transformation/spin_quant.py index f579e1352eb..f544e9e1f6e 100644 --- a/examples/models/llama2/source_transformation/spin_quant.py +++ b/examples/models/llama2/source_transformation/spin_quant.py @@ -9,7 +9,6 @@ # Helper functions for tranforming the model to be able to run SpinQuant. # See https://github.com/facebookresearch/SpinQuant for more details about SpinQuant. 
-from typing import Any, Optional import torch @@ -17,10 +16,6 @@ from executorch.examples.models.llama2.llama_transformer import FeedForward from torch import nn -from torchao.quantization.GPTQ import _check_linear_int4_k, Int8DynActInt4WeightLinear -from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter - -from .quantize import Int8DynActInt8WeightLinear, QuantizedGroupEmbedding def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module): @@ -91,171 +86,3 @@ def inject_fast_hadamard_transform_native_for_spin_quant( ) -> torch.nn.Module: _inject_fast_hadamard_transform_native_for_spin_quant(module) return module - - -def _replace_linear_with_linear_8da4w_for_spin_quant( - module: torch.nn.Module, - checkpoint: Any, - group_size: int, - precision: torch.dtype, - scales_precision: torch.dtype, -): - def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: - # Only replace linear layers where the checkpoint contains explicit scales - scales_key = f"{cur_fqn}.scales" - if isinstance(child, nn.Linear) and scales_key in checkpoint: - assert _check_linear_int4_k(child.in_features, group_size) - assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8 - assert checkpoint[scales_key].dtype == scales_precision - return True - return False - - def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: - new_linear = Int8DynActInt4WeightLinear( - child.in_features, - child.out_features, - bias=False, - device=child.weight.device, - groupsize=group_size, - precision=precision, - scales_precision=scales_precision, - ) - return new_linear - - _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn) - - -def transform_linear_for_spinquant( - module: torch.nn.Module, - checkpoint: Any, - group_size: int, - dtype: torch.dtype, -) -> torch.nn.Module: - """ - Transform the model to be able to load SpinQuant checkpoints that - are quantized with the given group size and quantization mode for - linear layers. - """ - - if group_size not in [32, 64, 128, 256]: - raise ValueError(f"Group size {group_size} is not supported for SpinQuant.") - _replace_linear_with_linear_8da4w_for_spin_quant( - module, - checkpoint, - group_size, - dtype, - dtype, - ) - return module - - -def _replace_output_linear_with_linear_int8_for_spinquant( - module: torch.nn.Module, - checkpoint: Any, - dtype: torch.dtype, -): - def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: - scales_key = f"{cur_fqn}.scales" - if ( - isinstance(child, nn.Linear) - and scales_key in checkpoint - and "output" in cur_fqn - ): - assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8 - assert checkpoint[scales_key].dtype == dtype - return True - return False - - def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: - new_linear = Int8DynActInt8WeightLinear( - device=child.weight.device, - in_features=child.in_features, - out_features=child.out_features, - precision=dtype, - bias=False, - ) - return new_linear - - _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn) - - -def transform_output_linear_for_spinquant( - module: torch.nn.Module, - checkpoint: Any, - dtype: torch.dtype, -) -> torch.nn.Module: - """ - Transform the model to be able to load SpinQuant checkpoints that - has the output layer quantized per-channel. 
- """ - _replace_output_linear_with_linear_int8_for_spinquant( - module, - checkpoint, - dtype, - ) - return module - - -def _replace_embedding_with_quantized_group_embedding_for_spinquant( - module: torch.nn.Module, - checkpoint: Any, - dtype: torch.dtype, - bit_width: int, - group_size: Optional[int] = None, -): - def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: - # Only replace embedding layers where the checkpoint contains explicit scales - scales_key = f"{cur_fqn}.scales" - if isinstance(child, nn.Embedding) and scales_key in checkpoint: - assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8 - assert checkpoint[scales_key].dtype == torch.float32 - return True - return False - - def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: - new_embedding = QuantizedGroupEmbedding( - device=child.weight.device, - vocab_size=child.weight.shape[0], - embedding_dim=child.weight.shape[1], - group_size=group_size, - dtype=dtype, - packed=False, # TODO(lunwenh): support packed embedding for SpinQuant - ) - return new_embedding - - _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn) - - -def transform_embedding_for_spinquant( - module: torch.nn.Module, - checkpoint: Any, - dtype: torch.dtype, - bit_width: int, - group_size: Optional[int] = None, -) -> torch.nn.Module: - """ - Transform the model to be able to load SpinQuant checkpoints that - are quantized with the given bit_width and group size for embedding. - """ - if group_size is not None and group_size not in [0, 32, 64, 128, 256]: - raise ValueError(f"Group size {group_size} is not supported for SpinQuant.") - _replace_embedding_with_quantized_group_embedding_for_spinquant( - module, - checkpoint, - dtype, - bit_width, - group_size, - ) - return module - - -def sanitize_checkpoint_from_spinquant( - checkpoint: Any, -): - """ - Sanitize the SpinQuant checkpoint. 
- - Converts all tensors to contiguous format - - Squeeze all tensors - """ - for k, v in checkpoint.items(): - checkpoint[k] = torch.squeeze(v.contiguous()) diff --git a/examples/models/llama2/tests/TARGETS b/examples/models/llama2/tests/TARGETS index 76981d8f317..2e4dcf7d1f6 100644 --- a/examples/models/llama2/tests/TARGETS +++ b/examples/models/llama2/tests/TARGETS @@ -15,9 +15,9 @@ python_unittest( ) python_unittest( - name = "test_spinquant_transforms", + name = "test_pre_quantization_transforms", srcs = [ - "test_spinquant_transforms.py", + "test_pre_quantization_transforms.py", ], deps = [ "//caffe2:torch", diff --git a/examples/models/llama2/tests/test_spinquant_transforms.py b/examples/models/llama2/tests/test_pre_quantization_transforms.py similarity index 86% rename from examples/models/llama2/tests/test_spinquant_transforms.py rename to examples/models/llama2/tests/test_pre_quantization_transforms.py index 4f6306814b6..59cec2e72ab 100644 --- a/examples/models/llama2/tests/test_spinquant_transforms.py +++ b/examples/models/llama2/tests/test_pre_quantization_transforms.py @@ -8,19 +8,19 @@ import torch from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer +from executorch.examples.models.llama2.source_transformation.pre_quantization import ( + sanitize_checkpoint_from_pre_quantization, + transform_embedding_for_pre_quantization, + transform_linear_for_pre_quantization, + transform_output_linear_for_pre_quantization, +) from executorch.examples.models.llama2.source_transformation.quantize import ( dynamically_quantize_per_channel, ) -from executorch.examples.models.llama2.source_transformation.spin_quant import ( - sanitize_checkpoint_from_spinquant, - transform_embedding_for_spinquant, - transform_linear_for_spinquant, - transform_output_linear_for_spinquant, -) from torchao.quantization.utils import group_quantize_tensor_symmetric -class SpinQuantTests(unittest.TestCase): +class PreQuantizationTests(unittest.TestCase): def _prepare_dummy_model(self) -> Transformer: model_args = ModelArgs( @@ -42,7 +42,7 @@ def _prepare_dummy_model(self) -> Transformer: return model - def test_transform_linear_for_spinquant(self): + def test_transform_linear_for_pre_quantization(self): # Step 1: Create llama class with dummy weights model = self._prepare_dummy_model() @@ -69,14 +69,13 @@ def test_transform_linear_for_spinquant(self): # Step 3: # Transform the model so that it is compatible with the new checkpoint - transform_linear_for_spinquant( + transform_linear_for_pre_quantization( model, checkpoint, 32, - "8da4w", torch.float32, ) - sanitize_checkpoint_from_spinquant(checkpoint) + sanitize_checkpoint_from_pre_quantization(checkpoint) model.load_state_dict( checkpoint, @@ -91,7 +90,7 @@ def test_transform_linear_for_spinquant(self): # have to iterate over the keys. 
self.assertTrue(torch.allclose(new_checkpoint[k], v)) - def test_transform_output_linear_for_spinquant(self): + def test_transform_output_linear_for_pre_quantization(self): # Step 1: Create llama class with dummy weights model = self._prepare_dummy_model() checkpoint = model.state_dict() @@ -114,12 +113,12 @@ def test_transform_output_linear_for_spinquant(self): # Step 3: # Transform the model so that it is compatible with the new checkpoint - transform_output_linear_for_spinquant( + transform_output_linear_for_pre_quantization( model, checkpoint, torch.float32, ) - sanitize_checkpoint_from_spinquant(checkpoint) + sanitize_checkpoint_from_pre_quantization(checkpoint) model.load_state_dict( checkpoint, @@ -134,7 +133,7 @@ def test_transform_output_linear_for_spinquant(self): # have to iterate over the keys. self.assertTrue(torch.allclose(new_checkpoint[k], v)) - def test_transform_embedding_for_spinquant(self): + def test_transform_embedding_for_pre_quantization(self): # Step 1: Create llama class with dummy weights model = self._prepare_dummy_model() @@ -162,14 +161,14 @@ def test_transform_embedding_for_spinquant(self): # Step 3: # Transform the model so that it is compatible with the new checkpoint - transform_embedding_for_spinquant( + transform_embedding_for_pre_quantization( model, checkpoint, torch.float32, n_bit, group_size, ) - sanitize_checkpoint_from_spinquant(checkpoint) + sanitize_checkpoint_from_pre_quantization(checkpoint) model.load_state_dict( checkpoint, diff --git a/pytest.ini b/pytest.ini index 701c0187ecf..dc77a910a1f 100644 --- a/pytest.ini +++ b/pytest.ini @@ -39,7 +39,7 @@ addopts = --ignore=backends/xnnpack/test/ops/linear.py --ignore=backends/xnnpack/test/models/llama2_et_example.py # T200992559: Add torchao to ET as core dependency - --ignore=examples/models/llama2/tests/test_spinquant_transforms.py + --ignore=examples/models/llama2/tests/test_pre_quantization_transforms.py --ignore=exir/backend/test/demos --ignore=exir/backend/test/test_backends.py --ignore=exir/backend/test/test_backends_lifted.py
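
Usage sketch (illustrative, not taken verbatim from the tree): the renamed helpers are meant to be chained in the same order as the call sequence in model.py above — output projection first (only for --preq_mode 8da4w_output_8da8w), then the remaining linear layers, then the embedding, then checkpoint sanitization before load_state_dict. The snippet below assumes a llama Transformer `model` already exists and that `checkpoint` is a pre-quantized state dict whose quantized layers carry int8 weights plus matching `.scales` tensors; the strict=False flag and the treatment of the "0" embedding group size are assumptions, everything else mirrors the function signatures added in pre_quantization.py.

    # Illustrative only -- assumes `model` (llama Transformer) and `checkpoint`
    # (pre-quantized state dict with int8 weights and `.scales` tensors) exist.
    import torch

    from executorch.examples.models.llama2.source_transformation.pre_quantization import (
        sanitize_checkpoint_from_pre_quantization,
        transform_embedding_for_pre_quantization,
        transform_linear_for_pre_quantization,
        transform_output_linear_for_pre_quantization,
    )

    dtype = torch.float32  # must match the --dtype_override / -d export flag

    # Only when --preq_mode is 8da4w_output_8da8w: swap the output projection
    # for Int8DynActInt8WeightLinear before touching the other layers.
    model = transform_output_linear_for_pre_quantization(model, checkpoint, dtype)

    # Remaining nn.Linear layers with a ".scales" entry in the checkpoint become
    # Int8DynActInt4WeightLinear; 32 matches the --preq_group_size default.
    model = transform_linear_for_pre_quantization(model, checkpoint, 32, dtype)

    # Embedding becomes QuantizedGroupEmbedding; "--preq_embedding_quantize 8,0"
    # means 8-bit weights, and the "0" group size is treated here as no sub-row
    # grouping (group_size=None) -- adjust if your checkpoint uses grouped scales.
    model = transform_embedding_for_pre_quantization(model, checkpoint, dtype, 8, None)

    # Make every checkpoint tensor contiguous and squeezed so it matches the
    # shapes the replacement modules expect.
    sanitize_checkpoint_from_pre_quantization(checkpoint)

    # assign=True because the model may have been built on the meta device;
    # strict=False is an assumption (unrelated buffers may be absent).
    model.load_state_dict(checkpoint, strict=False, assign=True)

In the export_llama flow this sequence is triggered by the --use_spin_quant branch in model.py together with the --preq_mode, --preq_group_size, and --preq_embedding_quantize flags shown in the README example above.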