From be1921ce63ce580f4bb86f37ef1aa4312d37dac5 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Fri, 21 Mar 2025 18:10:13 -0700 Subject: [PATCH] Fix xnnpack quantization discrepancy for non-fp32 (#8488) Summary: Perform quantization on the weights expressed in their original dtype (from the checkpoint) by passing in the checkpoint dtype to the quantization source transformation and modifying the computation dtype (the result dtype of the dequant, the dtype that the ops are actually computed in) to the dtype override. We must do it this way since the checkpoint and computation dtype are coupled into a single `precision` parameter in the torchao api, and that is something that we cannot change. Note - no need to worry about https://github.com/pytorch/ao/blob/main/torchao/quantization/GPTQ.py#L1168, precision is passed in with the checkpoint dtype ### Comparison of arbitrary q_proj tensor from sample Llama checkpoint: Before: ``` Mismatched elements: 3260378 / 4194304 (77.7%) Greatest absolute difference: 0.08802086114883423 at index (1129, 604) (up to 1e-05 allowed) Greatest relative difference: 1.0 at index (0, 1350) (up to 1.3e-06 allowed) Signal-to-noise: 32.8974 dB ``` After: no difference Test Plan: ### Manual testing ``` python -m examples.models.llama.export_llama \ -v -c xl_consolidated/consolidated_renamed.pth \ -p xl_consolidated/et_params.json -kv -d fp32 \ -qmode 8da4w --group_size 32 -X \ --use_sdpa_with_kv_cache \ --output_name quantized_baseline.pte \ --max_context_length 4096 -E 4,32 ``` With the following inserted after the quantization: ``` edge_manager.model( torch.tensor([[2, 3, 4]], dtype=torch.long), {"input_pos": torch.tensor([0], dtype=torch.long)}, ) ``` And the following modifications to GPTQ.py in torchao: https://github.com/pytorch/ao/pull/1756 for testing. ### Automated testing + existing CI tests ### Regression testing TBD Reviewed By: kimishpatel Differential Revision: D70184325 Pulled By: jackzhxng --- examples/models/checkpoint.py | 4 +- examples/models/llama/export_llama_lib.py | 67 +++++++++++-- .../llama/source_transformation/quantize.py | 99 +++++++++++++++---- examples/models/llava/export_llava.py | 2 +- exir/tests/test_memory_planning.py | 4 +- extension/llm/export/builder.py | 11 +++ 6 files changed, 159 insertions(+), 28 deletions(-) diff --git a/examples/models/checkpoint.py b/examples/models/checkpoint.py index c84a689b951..57a5b0ffaca 100644 --- a/examples/models/checkpoint.py +++ b/examples/models/checkpoint.py @@ -9,6 +9,8 @@ from pathlib import Path from typing import Any, Dict, Optional +import torch + def get_default_model_resource_dir(model_file_path: str) -> Path: """ @@ -52,7 +54,7 @@ def get_default_model_resource_dir(model_file_path: str) -> Path: return resource_dir -def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[str]: +def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[torch.dtype]: """ Get the dtype of the checkpoint, returning "None" if the checkpoint is empty. """ diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index cdcd5f89635..4de04b1ed11 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -16,6 +16,7 @@ import re import shlex from enum import Enum +from functools import partial from json import JSONDecodeError from pathlib import Path from typing import Callable, List, Optional, Union @@ -594,9 +595,36 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager: ) # At this point, the model is loaded in the default fp32. + + # Checkpoint dtype should be lower or equal precision to the dtype override. + checkpoint_dtype = edge_manager.model.checkpoint_dtype + if not ( + checkpoint_dtype == dtype_override.to_torch_dtype() + or ( + checkpoint_dtype == torch.float16 + and dtype_override.to_torch_dtype() == torch.float32 + ) + or ( + checkpoint_dtype == torch.bfloat16 + and dtype_override.to_torch_dtype() == torch.float32 + ) + ): + logging.warning( + f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}." + ) + edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype()) - edge_manager.set_output_dir(output_dir_path).source_transform( - _get_source_transforms(args.model, dtype_override, args) + + # We want to quantize (in the source transforms) the weights of the model + # in the checkpoint dtype. + logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}") + edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform( + _get_source_transforms( + modelname=args.model, + dtype_override=dtype_override, + checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype), + args=args, + ) ) return edge_manager @@ -784,8 +812,6 @@ def _to_edge_and_lower_llama( # noqa: C901 shares=args.num_sharding, ) - from functools import partial - # pyre-ignore from executorch.backends.qualcomm.quantizer.custom_annotation import ( get_custom_quant_ios_dtype, @@ -1069,8 +1095,31 @@ def _load_llama_model( def _get_source_transforms( # noqa - modelname: str, dtype_override: Optional[DType], args + modelname: str, + dtype_override: DType, + *, + checkpoint_dtype: Optional[DType] = None, + args, ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]: + """ + Return a list of functions that transform a graph. + + Args: + modelname: The name of the model. + dtype_override: The dtype to use for the model. + checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified, + it means that you want to run quantize transformations on the weights represented + in their original dtype, while the overall dtype of the model maybe something + different. If not specified, defaults to dtype_override. + args: The arguments passed to the script. + + Returns: + A list of transformation functions. + """ + + if not checkpoint_dtype: + checkpoint_dtype = dtype_override + transforms = [] if args.use_spin_quant: @@ -1103,7 +1152,11 @@ def _get_source_transforms( # noqa """ modelname = f"{modelname}_q" transforms.append( - get_quant_weight_transform(args, dtype_override, verbose_export()) + get_quant_weight_transform( + args=args, + computation_dtype=dtype_override, + checkpoint_dtype=checkpoint_dtype, + ) ) if args.embedding_quantize: @@ -1117,7 +1170,7 @@ def _get_source_transforms( # noqa this wil be a no-op. """ modelname = f"{modelname}_e" - transforms.append(get_quant_embedding_transform(args)) + transforms.append(get_quant_embedding_transform(args, checkpoint_dtype)) if args.expand_rope_table: transforms.append(materialze_broadcast_of_rope_freq_cis) diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py index 9f17c54b73e..17cff7c63fd 100644 --- a/examples/models/llama/source_transformation/quantize.py +++ b/examples/models/llama/source_transformation/quantize.py @@ -18,6 +18,7 @@ from sentencepiece import SentencePieceProcessor + try: from fairseq2.nn.embedding import ( Embedding as fsEmbedding, @@ -36,7 +37,8 @@ def quantize( # noqa C901 model: torch.nn.Module, qmode: str, - activation_dtype: Optional[DType], + computation_dtype: Optional[DType] = None, + checkpoint_dtype: Optional[DType] = None, checkpoint_path: Optional[Path] = None, # following arguments only available when setting int4 or gptq quantization. group_size: Optional[int] = 128, @@ -52,20 +54,33 @@ def quantize( # noqa C901 ) -> torch.nn.Module: """ Quantizes a model by converting all weights to int8. + Args: - model: A model to quantize. - qmode: quantization mode, e.g. int8, 8da4w, 8da4w-gptq + model: The model to quantize. + qmode: The quantization mode, e.g. int8, 8da4w, 8da4w-gptq. + computation_dtype: The dtype that ops are performed in (the resulting dtype of dequantization). + Also the dtype of the rest of the non-quantized compoents of the model. + checkpoint_dtype: The dtype of the checkpoint, this arg exists since it is more accurate to + quantize the weight in its original dtype. + Returns: A quantized model. """ - if activation_dtype is not None: - torch_dtype = activation_dtype.to_torch_dtype() + if computation_dtype: + computation_torch_dtype = computation_dtype.to_torch_dtype() else: - torch_dtype = torch.float16 + computation_torch_dtype = torch.float32 + + if not checkpoint_dtype: + checkpoint_torch_dtype = computation_torch_dtype + else: + checkpoint_torch_dtype = checkpoint_dtype.to_torch_dtype() if qmode == "int8": # Add quantization mode options here: group size, bit width, etc. - return WeightOnlyInt8QuantHandler(model).quantized_model() + return WeightOnlyInt8QuantHandler( + model, precision=checkpoint_torch_dtype + ).quantized_model() elif qmode.startswith("torchao:fpa"): pattern = r"torchao:fpa(\d+)w" matches = re.findall(pattern, qmode) @@ -75,10 +90,12 @@ def quantize( # noqa C901 from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer with torch.no_grad(): + # This quantize() is currently doing a model.to(self.precision) so cannot + # decouple computation and checkpoint dtypes. model = ( UIntxWeightOnlyLinearQuantizer( device="mps", - precision=torch.float32, + precision=computation_torch_dtype, groupsize=group_size, bitwidth=bitwidth, ) @@ -101,6 +118,8 @@ def quantize( # noqa C901 from torchao.utils import unwrap_tensor_subclass with torch.no_grad(): + # Computation dtype is fixed to fp32 in the implementation of quantize_, so + # no way to decouple checkpoint and computation dtype. quantize_( model, Int8DynamicActivationIntxWeightConfig( @@ -121,9 +140,12 @@ def quantize( # noqa C901 raise Exception("For 8da4w quantization, group size must be specified.") from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer + # 1. Quantize in checkpoint dtype. model = Int8DynActInt4WeightQuantizer( - precision=torch_dtype, groupsize=group_size + precision=checkpoint_torch_dtype, groupsize=group_size ).quantize(model) + # 2. Set the computation dtype (what weights/acts dequantize to). + model = set_8da4w_computation_dtype(model, computation_torch_dtype) if verbose: print("quantized model:", model) @@ -177,7 +199,7 @@ def quantize( # noqa C901 blocksize, percdamp, group_size, - ) + ) # TODO: separate computation and checkpoint dtype for GPTQ. model = gptq_quantizer.quantize(model, inputs) return model elif qmode == "vulkan_4w": @@ -190,9 +212,12 @@ def quantize( # noqa C901 # at the moment from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer + # 1. Quantize in checkpoint dtype. model = Int8DynActInt4WeightQuantizer( - precision=torch_dtype, groupsize=q_group_size + precision=checkpoint_torch_dtype, groupsize=q_group_size ).quantize(model) + # 2. Set the computation dtype (what weights/acts dequantize to). + model = set_8da4w_computation_dtype(model, computation_torch_dtype) return model else: @@ -348,6 +373,7 @@ def __init__( node_type: str = "*", bitwidth: Optional[int] = None, group_size: Optional[int] = None, + precision: torch.dtype = torch.float32, ): self.mod = mod self.group_size = group_size @@ -356,6 +382,7 @@ def __init__( self.bitwidth = 8 else: self.bitwidth = bitwidth + self.precision = precision @torch.no_grad() def create_quantized_state_dict(self) -> Dict: @@ -391,7 +418,7 @@ def create_quantized_state_dict(self) -> Dict: # print(f"expanded weight shape {input_weight.shape}") weight, scales, _ = dynamically_quantize_per_channel( - input_weight, + input_weight.to(dtype=self.precision), range_min, range_max, torch.int8, @@ -576,6 +603,7 @@ def __init__( bitwidth: int = 8, group_size: Optional[int] = None, packed=False, + precision: Optional[torch.dtype] = None, ): if isinstance(packed, str): packed = packed == "True" @@ -584,6 +612,8 @@ def __init__( self.group_size = group_size self.bitwidth = bitwidth self.packed = packed + # Dtype of the weights right before quantization. + self.precision = precision if (bitwidth not in [2, 4]) and packed: raise RuntimeError("pack only works with bitsize 2, 4") @@ -614,7 +644,11 @@ def create_quantized_state_dict(self, packed=False) -> Dict: f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}" ) weight, scales, _ = dynamically_quantize_per_channel( - mod.weight.float(), + ( + mod.weight.to(dtype=self.precision) + if self.precision + else mod.weight + ), range_min, range_max, torch.int8, @@ -750,7 +784,7 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor: ############################ Source Transform Start ####################### -def get_quant_embedding_transform(args): +def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None): if args.embedding_quantize.startswith("torchao:"): bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",") group_size = int(group_size) @@ -775,16 +809,22 @@ def _torchao_embedding_quantizer(model): else: group_size = int(group_size) bitwidth = int(bitwidth) + torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None return lambda model: EmbeddingQuantHandler( model, bitwidth=bitwidth, group_size=group_size, packed=(bitwidth in [2, 4]), + precision=torch_dtype, ).quantized_model() -def get_quant_weight_transform(args, dtype_override, verbose): - # If these optional args are None, don't provide them to quantize() +def get_quant_weight_transform( + args, + computation_dtype: Optional[DType] = None, + checkpoint_dtype: Optional[DType] = None, +): + # If these optional args are None, don't provide them to quantize(). quant_args_str = [ "group_size", "calibration_tasks", @@ -802,7 +842,8 @@ def get_quant_weight_transform(args, dtype_override, verbose): quantize, **quant_args, qmode=args.quantization_mode, - activation_dtype=dtype_override, + computation_dtype=computation_dtype, + checkpoint_dtype=checkpoint_dtype, checkpoint_path=(Path(path) if (path := args.checkpoint) is not None else None), tokenizer_path=( Path(path) if (path := args.tokenizer_path) is not None else None @@ -829,4 +870,28 @@ def _load_torchao_aten_lib(libname): torch.ops.load_library(libs[0]) +# We want to do compute the actual ops in the computation dtype, since the precision of the +# quantized linear will initially be the dtype of the checkpoint. +def set_8da4w_computation_dtype( + module: nn.Module, computation_dtype: torch.dtype +) -> nn.Module: + + from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear + + def _set_8da4w_computation_dtype(module: nn.Module, dtype: torch.dtype) -> None: + """ + Recursively iterate through the module and set the precision attributes + of all Int8DynActInt4WeightLinears. + """ + for _name, child in module.named_children(): + if isinstance(child, Int8DynActInt4WeightLinear): + child.precision = dtype + else: + # Recursively apply to child modules + _set_8da4w_computation_dtype(child, dtype) + + _set_8da4w_computation_dtype(module, computation_dtype) + return module + + ############################ Source Transform End ####################### diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py index a5057e5e850..494cfe7bfa3 100644 --- a/examples/models/llava/export_llava.py +++ b/examples/models/llava/export_llava.py @@ -100,7 +100,7 @@ def forward(self, input_pos, embeddings): args = parser.parse_args( ["-X", "-qmode", "8da4w", "--group_size", "128", "--embedding-quantize", "4,32"] ) - quant_transform = get_quant_weight_transform(args, dtype_override, False) + quant_transform = get_quant_weight_transform(args, dtype_override) _, quantizers, _ = get_quantizer_and_quant_params(args) source_transforms = [] if llava.use_sdpa_with_kv_cache_op: diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py index f3cddaff643..d885239acd8 100644 --- a/exir/tests/test_memory_planning.py +++ b/exir/tests/test_memory_planning.py @@ -708,10 +708,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: et_program = et.executorch_program inputs = et_program.execution_plan[0].inputs self.assertNotEqual( - et_program.execution_plan[0] # pyre-ignore + et_program.execution_plan[0] .values[inputs[0]] .val.allocation_info.memory_offset_low, - et_program.execution_plan[0] # pyre-ignore + et_program.execution_plan[0] .values[inputs[1]] .val.allocation_info.memory_offset_low, ) diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 9bca126f027..751e2d16175 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -61,6 +61,17 @@ def to_torch_dtype(self) -> torch.dtype: raise ValueError(f"Unsupported dtype {self}") return mapping[self] + @staticmethod + def from_torch_dtype(dtype: torch.dtype): + mapping = { + torch.float32: DType.fp32, + torch.float16: DType.fp16, + torch.bfloat16: DType.bf16, + } + if dtype not in mapping: + raise ValueError(f"Unsupported torch.dtype {dtype}") + return mapping[dtype] + class LLMEdgeManager: """