diff --git a/examples/models/checkpoint.py b/examples/models/checkpoint.py
index c84a689b951..57a5b0ffaca 100644
--- a/examples/models/checkpoint.py
+++ b/examples/models/checkpoint.py
@@ -9,6 +9,8 @@
 from pathlib import Path
 from typing import Any, Dict, Optional
 
+import torch
+
 
 def get_default_model_resource_dir(model_file_path: str) -> Path:
     """
@@ -52,7 +54,7 @@ def get_default_model_resource_dir(model_file_path: str) -> Path:
     return resource_dir
 
 
-def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[str]:
+def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[torch.dtype]:
     """
     Get the dtype of the checkpoint, returning "None" if the checkpoint is empty.
     """
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index cdcd5f89635..4de04b1ed11 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -16,6 +16,7 @@
 import re
 import shlex
 from enum import Enum
+from functools import partial
 from json import JSONDecodeError
 from pathlib import Path
 from typing import Callable, List, Optional, Union
@@ -594,9 +595,36 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     )
 
     # At this point, the model is loaded in the default fp32.
+
+    # Checkpoint dtype should be of equal or lower precision than the dtype override.
+    checkpoint_dtype = edge_manager.model.checkpoint_dtype
+    if not (
+        checkpoint_dtype == dtype_override.to_torch_dtype()
+        or (
+            checkpoint_dtype == torch.float16
+            and dtype_override.to_torch_dtype() == torch.float32
+        )
+        or (
+            checkpoint_dtype == torch.bfloat16
+            and dtype_override.to_torch_dtype() == torch.float32
+        )
+    ):
+        logging.warning(
+            f"Checkpoint dtype {checkpoint_dtype} is higher precision than the dtype override {dtype_override.to_torch_dtype()}."
+        )
+
     edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
-    edge_manager.set_output_dir(output_dir_path).source_transform(
-        _get_source_transforms(args.model, dtype_override, args)
+
+    # We want to quantize (in the source transforms) the weights of the model
+    # in the checkpoint dtype.
+    logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
+    edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
+        _get_source_transforms(
+            modelname=args.model,
+            dtype_override=dtype_override,
+            checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),
+            args=args,
+        )
     )
 
     return edge_manager
@@ -784,8 +812,6 @@ def _to_edge_and_lower_llama(  # noqa: C901
                 shares=args.num_sharding,
             )
 
-            from functools import partial
-
             # pyre-ignore
             from executorch.backends.qualcomm.quantizer.custom_annotation import (
                 get_custom_quant_ios_dtype,
@@ -1069,8 +1095,31 @@ def _load_llama_model(
 
 
 def _get_source_transforms(  # noqa
-    modelname: str, dtype_override: Optional[DType], args
+    modelname: str,
+    dtype_override: DType,
+    *,
+    checkpoint_dtype: Optional[DType] = None,
+    args,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
+    """
+    Return a list of functions that transform a graph.
+
+    Args:
+        modelname: The name of the model.
+        dtype_override: The dtype to use for the model.
+        checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
+            it means that you want to run quantize transformations on the weights represented
+            in their original dtype, while the overall dtype of the model may be something
+            different. If not specified, defaults to dtype_override.
+        args: The arguments passed to the script.
+
+    Returns:
+        A list of transformation functions.
+    """
+
+    if not checkpoint_dtype:
+        checkpoint_dtype = dtype_override
+
     transforms = []
 
     if args.use_spin_quant:
@@ -1103,7 +1152,11 @@ def _get_source_transforms(  # noqa
         """
         modelname = f"{modelname}_q"
         transforms.append(
-            get_quant_weight_transform(args, dtype_override, verbose_export())
+            get_quant_weight_transform(
+                args=args,
+                computation_dtype=dtype_override,
+                checkpoint_dtype=checkpoint_dtype,
+            )
         )
 
     if args.embedding_quantize:
@@ -1117,7 +1170,7 @@ def _get_source_transforms(  # noqa
         this wil be a no-op.
         """
         modelname = f"{modelname}_e"
-        transforms.append(get_quant_embedding_transform(args))
+        transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))
 
     if args.expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index 9f17c54b73e..17cff7c63fd 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -18,6 +18,7 @@
 
 from sentencepiece import SentencePieceProcessor
 
+
 try:
     from fairseq2.nn.embedding import (
         Embedding as fsEmbedding,
@@ -36,7 +37,8 @@
 def quantize(  # noqa C901
     model: torch.nn.Module,
     qmode: str,
-    activation_dtype: Optional[DType],
+    computation_dtype: Optional[DType] = None,
+    checkpoint_dtype: Optional[DType] = None,
     checkpoint_path: Optional[Path] = None,
     # following arguments only available when setting int4 or gptq quantization.
     group_size: Optional[int] = 128,
@@ -52,20 +54,33 @@
 ) -> torch.nn.Module:
     """
     Quantizes a model by converting all weights to int8.
+
     Args:
-        model: A model to quantize.
-        qmode: quantization mode, e.g. int8, 8da4w, 8da4w-gptq
+        model: The model to quantize.
+        qmode: The quantization mode, e.g. int8, 8da4w, 8da4w-gptq.
+        computation_dtype: The dtype that ops are performed in (the resulting dtype of dequantization).
+            Also the dtype of the rest of the non-quantized components of the model.
+        checkpoint_dtype: The dtype of the checkpoint. This arg exists because it is more accurate to
+            quantize the weights in their original dtype.
+
     Returns:
         A quantized model.
     """
-    if activation_dtype is not None:
-        torch_dtype = activation_dtype.to_torch_dtype()
+    if computation_dtype:
+        computation_torch_dtype = computation_dtype.to_torch_dtype()
     else:
-        torch_dtype = torch.float16
+        computation_torch_dtype = torch.float32
+
+    if not checkpoint_dtype:
+        checkpoint_torch_dtype = computation_torch_dtype
+    else:
+        checkpoint_torch_dtype = checkpoint_dtype.to_torch_dtype()
 
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
-        return WeightOnlyInt8QuantHandler(model).quantized_model()
+        return WeightOnlyInt8QuantHandler(
+            model, precision=checkpoint_torch_dtype
+        ).quantized_model()
     elif qmode.startswith("torchao:fpa"):
         pattern = r"torchao:fpa(\d+)w"
         matches = re.findall(pattern, qmode)
@@ -75,10 +90,12 @@
         from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
 
         with torch.no_grad():
+            # This quantize() currently does a model.to(self.precision), so we cannot
+            # decouple the computation and checkpoint dtypes.
             model = (
                 UIntxWeightOnlyLinearQuantizer(
                     device="mps",
-                    precision=torch.float32,
+                    precision=computation_torch_dtype,
                     groupsize=group_size,
                     bitwidth=bitwidth,
                 )
@@ -101,6 +118,8 @@
         from torchao.utils import unwrap_tensor_subclass
 
         with torch.no_grad():
+            # The computation dtype is fixed to fp32 in the implementation of quantize_, so
+            # there is no way to decouple the checkpoint and computation dtypes.
             quantize_(
                 model,
                 Int8DynamicActivationIntxWeightConfig(
@@ -121,9 +140,12 @@
             raise Exception("For 8da4w quantization, group size must be specified.")
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
 
+        # 1. Quantize in checkpoint dtype.
         model = Int8DynActInt4WeightQuantizer(
-            precision=torch_dtype, groupsize=group_size
+            precision=checkpoint_torch_dtype, groupsize=group_size
         ).quantize(model)
+        # 2. Set the computation dtype (what weights/acts dequantize to).
+        model = set_8da4w_computation_dtype(model, computation_torch_dtype)
 
         if verbose:
             print("quantized model:", model)
@@ -177,7 +199,7 @@
             blocksize,
             percdamp,
             group_size,
-        )
+        )  # TODO: separate computation and checkpoint dtype for GPTQ.
         model = gptq_quantizer.quantize(model, inputs)
         return model
     elif qmode == "vulkan_4w":
@@ -190,9 +212,12 @@
         # at the moment
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
 
+        # 1. Quantize in checkpoint dtype.
         model = Int8DynActInt4WeightQuantizer(
-            precision=torch_dtype, groupsize=q_group_size
+            precision=checkpoint_torch_dtype, groupsize=q_group_size
         ).quantize(model)
+        # 2. Set the computation dtype (what weights/acts dequantize to).
+        model = set_8da4w_computation_dtype(model, computation_torch_dtype)
 
         return model
     else:
@@ -348,6 +373,7 @@ def __init__(
         node_type: str = "*",
         bitwidth: Optional[int] = None,
         group_size: Optional[int] = None,
+        precision: torch.dtype = torch.float32,
     ):
         self.mod = mod
         self.group_size = group_size
@@ -356,6 +382,7 @@ def __init__(
             self.bitwidth = 8
         else:
             self.bitwidth = bitwidth
+        self.precision = precision
 
     @torch.no_grad()
     def create_quantized_state_dict(self) -> Dict:
@@ -391,7 +418,7 @@ def create_quantized_state_dict(self) -> Dict:
 
                 # print(f"expanded weight shape {input_weight.shape}")
                 weight, scales, _ = dynamically_quantize_per_channel(
-                    input_weight,
+                    input_weight.to(dtype=self.precision),
                     range_min,
                     range_max,
                     torch.int8,
@@ -576,6 +603,7 @@ def __init__(
         bitwidth: int = 8,
         group_size: Optional[int] = None,
         packed=False,
+        precision: Optional[torch.dtype] = None,
    ):
         if isinstance(packed, str):
             packed = packed == "True"
@@ -584,6 +612,8 @@ def __init__(
         self.group_size = group_size
         self.bitwidth = bitwidth
         self.packed = packed
+        # Dtype of the weights right before quantization.
+        self.precision = precision
 
         if (bitwidth not in [2, 4]) and packed:
             raise RuntimeError("pack only works with bitsize 2, 4")
@@ -614,7 +644,11 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
                     f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                 )
                 weight, scales, _ = dynamically_quantize_per_channel(
-                    mod.weight.float(),
+                    (
+                        mod.weight.to(dtype=self.precision)
+                        if self.precision
+                        else mod.weight
+                    ),
                     range_min,
                     range_max,
                     torch.int8,
@@ -750,7 +784,7 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
 ############################ Source Transform Start #######################
 
 
-def get_quant_embedding_transform(args):
+def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
     if args.embedding_quantize.startswith("torchao:"):
         bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
         group_size = int(group_size)
@@ -775,16 +809,22 @@ def _torchao_embedding_quantizer(model):
     else:
         group_size = int(group_size)
         bitwidth = int(bitwidth)
+    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
     return lambda model: EmbeddingQuantHandler(
         model,
         bitwidth=bitwidth,
         group_size=group_size,
         packed=(bitwidth in [2, 4]),
+        precision=torch_dtype,
     ).quantized_model()
 
 
-def get_quant_weight_transform(args, dtype_override, verbose):
-    # If these optional args are None, don't provide them to quantize()
+def get_quant_weight_transform(
+    args,
+    computation_dtype: Optional[DType] = None,
+    checkpoint_dtype: Optional[DType] = None,
+):
+    # If these optional args are None, don't provide them to quantize().
     quant_args_str = [
         "group_size",
         "calibration_tasks",
@@ -802,7 +842,8 @@ def get_quant_weight_transform(args, dtype_override, verbose):
         quantize,
         **quant_args,
         qmode=args.quantization_mode,
-        activation_dtype=dtype_override,
+        computation_dtype=computation_dtype,
+        checkpoint_dtype=checkpoint_dtype,
         checkpoint_path=(Path(path) if (path := args.checkpoint) is not None else None),
         tokenizer_path=(
             Path(path) if (path := args.tokenizer_path) is not None else None
@@ -829,4 +870,28 @@ def _load_torchao_aten_lib(libname):
     torch.ops.load_library(libs[0])
 
 
+# We want to compute the actual ops in the computation dtype, since the precision of the
+# quantized linear will initially be the dtype of the checkpoint.
+def set_8da4w_computation_dtype(
+    module: nn.Module, computation_dtype: torch.dtype
+) -> nn.Module:
+
+    from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+
+    def _set_8da4w_computation_dtype(module: nn.Module, dtype: torch.dtype) -> None:
+        """
+        Recursively iterate through the module and set the precision attributes
+        of all Int8DynActInt4WeightLinears.
+        """
+        for _name, child in module.named_children():
+            if isinstance(child, Int8DynActInt4WeightLinear):
+                child.precision = dtype
+            else:
+                # Recursively apply to child modules
+                _set_8da4w_computation_dtype(child, dtype)
+
+    _set_8da4w_computation_dtype(module, computation_dtype)
+    return module
+
+
 ############################ Source Transform End #######################
diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py
index a5057e5e850..494cfe7bfa3 100644
--- a/examples/models/llava/export_llava.py
+++ b/examples/models/llava/export_llava.py
@@ -100,7 +100,7 @@ def forward(self, input_pos, embeddings):
     args = parser.parse_args(
         ["-X", "-qmode", "8da4w", "--group_size", "128", "--embedding-quantize", "4,32"]
     )
-    quant_transform = get_quant_weight_transform(args, dtype_override, False)
+    quant_transform = get_quant_weight_transform(args, dtype_override)
     _, quantizers, _ = get_quantizer_and_quant_params(args)
     source_transforms = []
     if llava.use_sdpa_with_kv_cache_op:
diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py
index f3cddaff643..d885239acd8 100644
--- a/exir/tests/test_memory_planning.py
+++ b/exir/tests/test_memory_planning.py
@@ -708,10 +708,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         et_program = et.executorch_program
         inputs = et_program.execution_plan[0].inputs
         self.assertNotEqual(
-            et_program.execution_plan[0]  # pyre-ignore
+            et_program.execution_plan[0]
             .values[inputs[0]]
             .val.allocation_info.memory_offset_low,
-            et_program.execution_plan[0]  # pyre-ignore
+            et_program.execution_plan[0]
             .values[inputs[1]]
             .val.allocation_info.memory_offset_low,
         )
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 9bca126f027..751e2d16175 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -61,6 +61,17 @@ def to_torch_dtype(self) -> torch.dtype:
             raise ValueError(f"Unsupported dtype {self}")
         return mapping[self]
 
+    @staticmethod
+    def from_torch_dtype(dtype: torch.dtype):
+        mapping = {
+            torch.float32: DType.fp32,
+            torch.float16: DType.fp16,
+            torch.bfloat16: DType.bf16,
+        }
+        if dtype not in mapping:
+            raise ValueError(f"Unsupported torch.dtype {dtype}")
+        return mapping[dtype]
+
 
 class LLMEdgeManager:
     """
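A minimal usage sketch of the decoupled checkpoint/computation dtypes added above. It is illustrative only: the in-repo import paths and the toy nn.Sequential stand-in for a model loaded from a bf16 checkpoint are assumptions, not part of this change.

import torch

from executorch.examples.models.llama.source_transformation.quantize import quantize
from executorch.extension.llm.export.builder import DType

# Round-trip between torch dtypes and the exporter's DType enum.
ckpt_dtype = DType.from_torch_dtype(torch.bfloat16)  # DType.bf16
assert ckpt_dtype.to_torch_dtype() is torch.bfloat16

# Toy stand-in for an eager model loaded from a bf16 checkpoint.
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).to(torch.bfloat16)

# 8da4w path: weights are quantized while still in the checkpoint dtype (bf16),
# then quantize() switches the dequantized linears to the fp32 computation dtype
# via set_8da4w_computation_dtype.
model = quantize(
    model,
    qmode="8da4w",
    computation_dtype=DType.fp32,
    checkpoint_dtype=ckpt_dtype,
    group_size=32,
)

The point of the split is that Int8DynActInt4WeightQuantizer sees the weights in their original precision, while the non-quantized parts of the model and the dequantized ops run in the computation dtype.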