diff --git a/docs/source/api_ref_utils.rst b/docs/source/api_ref_utils.rst
new file mode 100644
index 0000000000..3e85bbb424
--- /dev/null
+++ b/docs/source/api_ref_utils.rst
@@ -0,0 +1,33 @@
+.. _api_utils:
+
+
+=============
+torchao.utils
+=============
+
+.. currentmodule:: torchao.utils
+
+Tensor Subclass Utils
+---------------------
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+
+    TorchAOBaseTensor
+
+=====================================
+torchao.quantization.quantize_.common
+=====================================
+
+.. currentmodule:: torchao.quantization.quantize_.common
+
+quantize_ API Common Utils
+--------------------------
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+
+    KernelPreference
+    PackingFormat
+    QuantizeTensorKwargs
+    _choose_quant_func_and_quantize_tensor
diff --git a/docs/source/contributor_guide.rst b/docs/source/contributor_guide.rst
index ab6d433e27..353ba754ca 100644
--- a/docs/source/contributor_guide.rst
+++ b/docs/source/contributor_guide.rst
@@ -4,16 +4,34 @@ Contributor Guide
 General Guide on Extending torchao
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-For a new use case, for example, a training dtype (like fp4 training), it's fine to start with adding a new tensor subclass in prototype folder `torchao/prototype `__, but you could also take a look at ``AffineQuantizedTensor`` if what you want to do is mostly supported there, e.g. adding int3 kernel for the exact same affine quantization. Please feel free to open an issue and if you have questions on what to do for a specific new use case. For more details, please refer to our `quantization overview page `__.
+Please start by reading our `quantization overview page `__.
 
 To contribute to existing code base:
 
-* Adding features to AffineQuantizedTensor, e.g. making it trainable, add tensor parallelism support etc.: `torchao/dtypes/affine_quantized_tensor.py `__
+* Adding a new Tensor: `torchao/quantization/quantize_/workflows `__
 * Adding new quantization APIs: `torchao/quantization/quant_api.py `__
+* Adding features to existing Tensor subclasses like ``Float8Tensor``, e.g. adding new operator support, making it trainable, adding tensor parallelism support etc.: `tensor subclasses `__, `tests `__
 * Adding new quantization primitive ops, e.g. slight variations of existing quantization primitive ops: `torchao/quantization/quant_primitives.py `__
 * Adding new autotuned triton kernels: `torchao/kernel `__
 * Adding new custom cpu/cuda/mps kernels: `torchao/csrc `__
-* Integrating custom kernel with AffineQuantizedTensor (maybe a new layout as well): Add sparse marlin AQT layout `#621 `__ as an example. We are still not decided if we want to split ``AffineQuantizedTensor`` to more tensor subclasses or not.
+
+Adding New Tensor Subclasses
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+torchao Tensor subclasses are structured by ``derived dtype`` and ``packing format``; please check out the `quantization overview page `__ to understand these concepts. If your use case needs a new derived dtype or a new packing format that does not already exist, you can define a new Tensor.
+
+To understand how to use tensor subclasses in the context of quantization, please also check `Writing Your Own Quantized Tensor `__.
+
+We have a utility base class, ``torchao.utils.TorchAOBaseTensor``, that defines common utility functions and methods for you, as long as you specify the names of the Tensor and non-Tensor attributes of the tensor subclass, for example::
+
+    class MyTensor(TorchAOBaseTensor):
+        tensor_data_names = ["qdata", "scale"]
+        tensor_attribute_names = ["device", "dtype"]
+
+
+With the above, multiple methods and functions become available for this Tensor; for more details please check the docs for `TorchAOBaseTensor `__.
+
+.. note::
+    Many of the existing use cases in torchao still use ``AffineQuantizedTensor``, but we plan to move away from it to reduce the abstractions and make it easier for people to contribute to torchao.
 
 Adding Efficient Kernels
 ~~~~~~~~~~~~~~~~~~~~~~~~
@@ -31,50 +49,59 @@ Custom hand written kernels
 ###########################
 Custom kernels (implementations) for cpu/cuda/mps can be implemented through `torchao/csrc `__ e.g. int4 cuda, and accessible through torch.ops.my_custom_op
 
-Dispatches
-~~~~~~~~~~
+Using hand written kernels in Tensor Subclasses
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+To call optimized kernels, we use the ``implements`` decorator from the tensor subclass. For example, if we want to call into a new custom op ``torch.ops.torchao.my_mm_for_mps``::
+
+    class Float8Tensor(TorchAOBaseTensor):
+        ...
+
+    implements = Float8Tensor.implements
-For dispatching to optimized kernels for cpu/cuda/mps devices, we can have checks for the dispatch conditions in ``__torch_function__`` or ``__torch_dispatch__`` and dispatch to target operators, for example, condition for bfloat16 activation and uint4 weight kernel can be found `here `__.
+
+    @implements([torch.nn.functional.linear, aten.linear.default])
+    def _(func, types, args, kwargs):
+        ...
+        # call into the custom op
+        res = torch.ops.torchao.my_mm_for_mps(input_tensor.qdata, weight_tensor.qdata, input_tensor.scale, weight_tensor.scale)
+        return res
 
-Specifically for ``AffineQuantizedTensor``, we also allow people to extend the quantized linear to use a new efficient kernel or implement by defining two functions:
-``dispatch_condition`` (defines the condition to dispatch to the kernel) and impl (actual implementation that takes activation, (quantized) weight, bias Tensor and runs the efficient kernel), both taking ``input_tensor``, ``weight_tensor``, ``bias`` as argument, and can be registered into dispatch of quantized linear in ``AffineQuantizedTensor`` with ``register_aqt_quantized_linear_dispatch``. `Here `__ is an example showing how it works.
+KernelPreference
+################
 
-Layout/TensorImpl
-~~~~~~~~~~~~~~~~~
+For some tensor subclasses, there could be multiple kernel choices for quantize, mm, etc. The recommended way to handle this in torchao tensor subclasses is through ``KernelPreference``, which represents which group of kernels we want to use for quantize, mm, grouped mm, etc. ``KernelPreference.AUTO`` is the default option, where developers choose whatever we think is fastest under the given conditions, so users don't need to worry about the details, and we can have other more specific kernel options for debugging purposes.
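+
+For a rough idea of how a kernel preference surfaces to users, here is a minimal sketch (this assumes the config you use exposes a ``kernel_preference`` argument; check the specific config's signature). The options available for ``Float8Tensor`` are listed right after this example::
+
+    import torch
+
+    from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_
+    from torchao.quantization.quantize_.common import KernelPreference
+
+    model = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().to(torch.bfloat16)
+    # explicitly pick the torch-native kernel path instead of the AUTO default,
+    # e.g. when debugging or comparing kernel choices
+    quantize_(model, Float8DynamicActivationFloat8WeightConfig(kernel_preference=KernelPreference.TORCH))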
+
+``Float8Tensor``, for example, has:
+
+* ``KernelPreference.AUTO``: chooses the most performant quantize and mm kernel based on hardware (SM89 or SM90+, e.g. H100), availability of libraries (whether ``fbgemm_gpu_genai`` is installed) and granularity (per row or per tensor)
+* ``KernelPreference.TORCH``: uses the torchao quantize ops (``_choose_scale_float8`` and ``_quantize_affine_float8``) and ``_scaled_mm``
+* ``KernelPreference.FBGEMM``: uses the fbgemm quantize and mm ops (``torch.ops.fbgemm.f8f8bf16_rowwise``)
-Sometimes the quantized weights has to be packed in order to yield optimal performance. And this can be abstracted with ``layout``. See `here `__ for full example.
 
Flow
~~~~
-After the tensor subclass is implemented, we can also wrap that into factory functions, e.g.::
-    # convert from floating point tensor to my dtype tensor subclass
-    to_my_dtype = MyDTypeTensor.from_float
-
-For model level API, people can reuse ``torchao.quantize_`` that allows people to apply a tensor subclass conversion to weight of linear, and allows `filtering function `__ to choose which module the tensor subclass conversion should be applied to.
+For the model level API, people can reuse ``torchao.quantize_``, which applies a tensor subclass conversion to the weight of linear modules and accepts a `filtering function `__ to choose which modules the conversion should be applied to.
 
-See Quantization Algorithms/Flows section for examples of weight only/dynamic quant/static quant and other types of model level APIs based on the factory function.
+See the Quantization Algorithms/Flows section for examples of weight only/dynamic quant and other types of model level APIs.
 
Using torch.compile for Performance
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-Note: for pytorch 2.4 and below, we need to use the following::
-    from torchao.utils import unwrap_tensor_subclass
-    m_unwrapped = unwrap_tensor_subclass(m)
-
In order to be compatible with ``torch.compile``, to aim for performance optimization, we should run through ``torch.compile`` with ``fullgraph=True`` first, and remove any unnecessary graph breaks. You can add ``TORCH_LOGS="output_code"`` when you run the script in order to see the inductor generated code. e.g. ``TORCH_LOGS="output_code" python example.py``::
+
+    model = torch.compile(model, mode="max-autotune", fullgraph=True)
 
Serialization
~~~~~~~~~~~~~
-Please checkout the `serialization doc `__ for more details.
+To enable support for serialization (torch.save and torch.load with tensor subclasses as weights), we need to add the tensor subclass and the relevant objects to safe globals (available in torch 2.5 and later), e.g.::
+
+    torch.serialization.add_safe_globals([Float8Tensor, QuantizeTensorToFloat8Kwargs])
-.. note::
-    We are integrated with huggingface transformer and supports serialization/deserialization through the huggingface save_pretrained/push_to_hub/from_pretrained APIs: https://huggingface.co/docs/transformers/main/en/quantization/torchao
+
+Please check out the `serialization doc `__ for more details.
 
.. note::
-    Another example can be found in integration with diffuser: https://github.com/sayakpaul/diffusers-torchao/blob/main/inference/serialization_and_loading.md
+    We are `integrated `__ with huggingface transformers and support serialization and deserialization through the huggingface ``save_pretrained``, ``push_to_hub`` and ``from_pretrained`` APIs. We also have `serialization examples `__ with diffuser models.
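+
+To make the save/load flow concrete, here is a minimal sketch (the checkpoint path is hypothetical, ``model`` is assumed to already be quantized with a float8 config, and the exact classes to register depend on the config you used, as in the example above)::
+
+    import torch
+
+    from torchao.quantization.quantize_.workflows import Float8Tensor
+
+    # save: the quantized weights are still torch.Tensors, so this is a normal state_dict save
+    torch.save(model.state_dict(), "quantized_checkpoint.pt")
+
+    # load: register the tensor subclass (and any related kwargs classes, as shown above)
+    # so that torch.load with weights_only=True accepts it
+    torch.serialization.add_safe_globals([Float8Tensor])
+    state_dict = torch.load("quantized_checkpoint.pt", weights_only=True)
+    model.load_state_dict(state_dict, assign=True)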
Other Feature Support @@ -85,8 +112,6 @@ The above just talks about basic feature support, we also provide examples on ho * `Quantized Training `__ * `Tensor Parallel Support for Quantized Tensor `__ * `Compatibility with executorch / torchchat `__ -* [TODO] FSDP -* [TODO] QAT Tensor Subclass Functionality/Composability Testing @@ -126,11 +151,16 @@ After you have the quantization flow implemented, you can run benchmark and eval Note: llama model (llama2/llama3) is our representative model for memory bound models and sam is our representative model for compute bound models. * `llama `__ + * `benchmark `__ * `eval `__ + * `sam `__ + * `benchmark and eval `__ Please checkout the ``--help`` option for each of the script to understand the supported options, e.g. you can use ``--profile=profile_path`` to get the chrome trace of the run to understand detailed `chrome trace `__. Please let us know if there are any new important models that makes sense to be added to torchao model benchmark/eval folder. + +Please also check out `Benchmarking User Guide `__ and `Benchmarking API Guide `__ to understand how to use our benchmarking framework. diff --git a/docs/source/index.rst b/docs/source/index.rst index 7e376432a3..d05f2bd60a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -18,9 +18,9 @@ for an overall introduction to the library and recent highlight and updates. :maxdepth: 1 :caption: Developer Notes - quantization - sparsity + quantization_overview contributor_guide + sparsity benchmarking_api_guide benchmarking_user_guide @@ -34,6 +34,7 @@ for an overall introduction to the library and recent highlight and updates. api_ref_qat api_ref_sparsity api_ref_float8 + api_ref_utils .. toctree:: :glob: diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst deleted file mode 100644 index 929bc1d00c..0000000000 --- a/docs/source/quantization.rst +++ /dev/null @@ -1,243 +0,0 @@ -Quantization Overview ---------------------- - -First we want to lay out the torchao stack:: - - Quantization Algorithms/Flows: weight only/dynamic/static quantization, hqq, awq, gptq etc. - --------------------------------------------------------------------------------------------- - Quantized Tensors (derived dtypes): AffineQuantizedTensor, CodebookQuantizedTensor - --------------------------------------------------------------------------------------------- - Quantization Primitive Ops/Efficient Kernels: matmul, quantize, dequantize - --------------------------------------------------------------------------------------------- - Basic dtypes: uint1-uint7, int1-int8, float3-float8 - - -Any quantization algorithm will be using some components from the above stack, for example int4 weight-only quantization uses: -(1) weight only quantization flow -(2) `tinygemm bf16 activation + int4 weight kernel `__ and `quant primitive ops `__ -(3) `AffineQuantizedTensor `__ tensor subclass with `TensorCoreTiledLayout `__ -(4) torch.uint4 dtype (simulated with quant_min/quant_max right now) - -Note: we'll also talk about how to compose sparsity with quantization in the Quantized Tensors section - -Basic DTypes -~~~~~~~~~~~~ -`dtype `__ is a bit of overloaded term, by basic dtype, we mean the dtypes that makes sense without any extra metadata (e.g. 
makes sense when people call ``torch.empty(.., dtype)``), for more details please check out: dev-discuss.pytorch.org/t/supporting-new-dtypes-in-pytorch/1833 - -No matter what quantization we are doing, in the end we will be using some low precision dtypes to represent the quantized data, the dtypes we aim to support in torchao are: - -* ``torch.uint1`` to ``torch.uint8`` available in pytorch 2.3 and later -* ``torch.int1`` to ``torch.int8`` available in pytorch 2.6 and later -* ``torch.float3_e2_m0``, ``torch.float4_e2_m1``, ``torch.float4_e3_m0``, ``torch.float5_e2_m2``, ``torch.float5_e3_m1``, ``torch.float6_e2_m3``, ``torch.float6_e3_m2``, ``torch.float8_e4m3fn``, ``torch.float8_e5m2``, ``torch.float8_e4m3fnuz``, ``torch.float8_e5m2fnuz`` (float8 is added to torch, we also plan to add float4 and float6 to torch if they become popular) - -Note some of the above are prototype only for now. We'll consider adding then to pytorch core when they become popular and have hardware support. - -Current Support -############### -In terms of actual implementation, there are two parts: -1). In PyTorch, we need to add the dtype to torch.dtype, e.g. torch.uint2, example: pytorch/pytorch#117208, but these are just placeholders so that we can use torch.uint2. -2). Outside of PyTorch (e.g. in torchao), we implement the tensor operations for these dtypes with tensor subclasses, also a standard packing format is needed. - -Adding placeholder dtype in PyTorch -*********************************** - -As mentioned in dev-discuss.pytorch.org/t/supporting-new-dtypes-in-pytorch/1833, the criteria for adding dtype in PyTorch is that it shows wide adoption. For the above mentioned fundamental dtypes, the ones that are supported in PyTorch are: - -* ``torch.uint1`` to ``torch.uint8``, ``torch.int1`` to ``torch.int8``, ``torch.float8_e4m3fn``, ``torch.float8_e5m2``, ``torch.float8_e4m3fnuz``, ``torch.float8_e5m2fnuz`` - -For the other types we plan to wait until there is more evidence of wide adoption and hardware support. - -Implementing tensor operations for these dtypes with Tensor subclasses -********************************************************************** -For this, the requirement is we decide on a "standard" packing format, and hopefully one that is amenable to efficient implementation, but for both uintx and floatx we haven't integrate enough kernels to decide on this. So current `packing implementations `__ are ont final. We can revisit after there are more uintx, intx and floatx kernels being integrated into torchao. - -Integrate Tensor subclass to pytorch native factory functions -************************************************************* -After that we can connect the factory function with the tensor subclass, for example: ``torch.empty(..., dtype=torch.int4, ...)`` can create a ``Int4Tensor`` tensor subclass with the packing format decided in the previous step. - -Quantization Primitive Ops -~~~~~~~~~~~~~~~~~~~~~~~~~~ -Quantization primitive ops means the operators used to convert between low preicison quantized tensors and high precision tensors. We will mainly have the following quantization primitive operators: -choose_qparams ops: that chooses quantization parameter based on the original Tensor, typically used in dynamic quantization, e.g. 
scale and zero_point for affine quantization -quantize op: quantizes the original high precision tensor to the low precision tensor with the dtypes mentioned in previous section based on the quantization parameters -dequantize op: dequantizes the low precision tensor into the high precision tensor based on quantization parameters - -There could be variations of the above to accommodate specific use cases, for example for static quantization we may have ``choose_qparams_affine_with_min_max`` that will choose quantization parameters based on min/max values derived from the observation process. - -Efficient kernels -~~~~~~~~~~~~~~~~~ -We'll also have efficient kernels that works with the low precision tensors, for example - -`_weight_int4pack_mm `__ the tinygemm int4 kernel (bf16 activation + int4 weight) -`int_matmul `__ that takes two int8 tensors and outputs an int32 tensor -`int_scaled_matmul `__ that does matmul and also applies a scale to the result. - -Note: We can also rely on torch.compile to generate kernels (through triton), for example the current int8 weight only quantization `kernel `__ just relies on torch.compile to get speedup. In this case there is no specific "efficient kernel" that's corresponding to the type of quantization. - -Quantized Tensors (derived dtypes) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -On top of the basic dtypes, quantization primitive operators and efficient kernels, we can glue everything together and build out a Quantized (low precision) Tensor by subclassing torch.Tensor that can be constructed from a high precision Tensor and some parameters that can configure the specific quantization user wants, we can also call this derived dtypes since it can be represented with Tensors of basic dtypes and some extra metadata like scale. - -Existing example in torchao is ``AffineQuantizedTensor``, meaning the low precision Tensor is quantized from the high precision Tensor by an affine mapping, that is: ``low_precision_val = high_precision_val / scale + zero_point``, where ``scale``/``zero_point`` are the quantization parameters that can be calculated by quantization primitive ops or through some optimization procedure. Affine quantization is a very common type of quantization, since it's straightforward that when we try to map from higher precision values to lower precision values, we do an affine transformation (``high_preicsion_val / scale + zero_point``). Another common type of quantization, especially for lower bitwidths (e.g. lower than 4 bit) is codebook / look up table based quantization. - -Layout and TensorImpl -##################### -Native tensors have a hardcoded list of selections of `layout `__, most common one is strided layout, it provides a strided, multi-dimensional view of storage, we also have some sparse and mkldnn layout. - -Take `sparse COO tensor `__ as an example, it has `torch.sparse_coo` layout, and `SparseTensorImpl `__ which changes how the tensor is stored. - -The idea of packing the tensor into different formats fits nicely with the layout concept, that’s why we want to reuse this for packing. We can use `Layout` for different type of packing format and `TensorImpl` for different storage format implementations. And new TensorImpl that stores the Tensor in a packed format can be added at python level tensor subclasses without modifying C++ pytorch core code. - -For example, for ``_weight_int4pack_mm`` we need to pack the weight to an format that is friendly for Tensor Core, we call it `TensorCoreTiledLayout `__. 
We add a ``tensor_impl`` for the quantized tensor to store the packed (or unpacked) weight, and we use ``layout`` to store different parameters that's relevant for packing:: - - class AffineQuantizedTensor(...): - # tensor_impl is also implemented with tensor subclass - tensor_impl: torch.Tensor - - # to not conflict with existing layout property, we use `_layout` - @property - def _layout(self) -> Layout: - return self.tensor_impl._layout - -Note that layout is an abstraction not only for custom data representation, it is also used for how the -`TensorImpl` interacts with different operators, e.g. the same data representation can have different -implementations when running the same operator, e.g. transpose, quantized_linear, but the operator semantics should stay the same. - -Quantize + Sparse Tensor can also be supported through the Layout abstraction, for example, `int4 weight only quantization + sparse `__. We also provide some common utils that helps people to add different layouts to a quantized tensor, please check out the developer guide below for code examples. - -Quantization Algorithms/Flows -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -On the top of the stack will be the final quantization algorithms and quantization flows. Traditionally we have weight only quantization, dynamic quantization and static quantization, but now we are also seeing more types of quantization coming up. - -For demonstration purposes, let's say after previous step we have ``AffineQuantizedTensor`` and ``to_affine_quantized`` factory function defined. For simplicity, let's say ``to_affine_quantized`` takes a high precision floating point Tensor and a target_dtype (e.g. torch.int8) and converts it to an ``AffineQuantizedTensor`` with corresponding dtype. - -Note: below are all for explaining the concepts, more detailed introduction for utils and examples we provide can be found in ``Tensor Subclass Developer Guide`` section. - -Weight Only Quantization -######################## -This is the simplest form of quantization and it's easy to apply weight only quantization to the model, especially since we have Quantized Tensor. all we need to do is:: - linear_module.weight = torch.nn.Parameter(to_affine_quantized_intx(linear_module.weight, ...), requires_grad=False)) - -apply the above to all linear modules in the model and we'll get a weight only quantized model. - -Dynamic Activation and Weight Quantization -########################################## - -This is called "dynamic quantization" before but it means we quantize activation dynamically at runtime, and also quantize the weights as well. Compared to the weight only quantization, the main question is how do we apply the quantization to activation. In torchao, the common pattern we use is by applying ``to_linear_activation_quantized`` on top of quantized weight:: - quantized_weight = to_affine_quantized(linear_module.weight) - activation_and_weight_quantized = to_linear_activation_quantized(quantized_weight) - linear_module.weight = torch.nn.Parameter(activation_and_weight_quantized, requires_grad=False)) - -``to_linear_activation_quantized`` is used to apply quantization to activation, it takes a ``input_quant_func`` that will quantize the activation and the original weight, and during runtime when it encounters a ``F.linear`` op, it will apply the stored input_qunat_func to activation and redispatch to ``F.linear`` with quantized activation and weight. 
- -If the above does not work, user can also do module swaps, or use ``torch.fx.symbolic_trace()`` to get a traced module that you can `modify `__. - -But using tensor subclass is preferred because it is easier for serialization/deserialization, if we use tensor subclasses to support dynamic quantization, then we can load the quantized weights directly without further preparation for the model. Otherwise, we'd need to do module swap or other modifications to the model first before loading the quantized weights. - -Static Activation Quantization and Weight Quantization -###################################################### -Static quantization means activation is statically quantized instead of dynamically quantized at runtime. In terms of flow, static quantization requires calibration with sample data in order that we can figure out the appropriate quantization parameters. - -At the high level there are three steps for static quantization: (1) insert observers (2) calibration (3) quantize the model - - -Insert Observers -**************** -In insert observers step, we need to add observer modules to input (and output) activation and weight of the operator to collect statistics of the Tensor. So there are two things we need to address, how to define observer module? how to add observer module to the model. - -How to define observer module -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Observers are specific to: (1) type of quantization (e.g. affine quantization, look up table based quantization) (2) type of stats we want to track, e.g. min max observer, moving average observer. - -Generally an observer module should define `forward `__ and `calculate_qparams `__ - -For affine quantization, we defined `AffineQuantizedMinMaxObserver `__ that records min_val/max_val based on the granularity of affine quantization, and also defines how to calculate_qparams based on the recorded stats. - -How to add observer module to the model -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -1. Use Tensor Subclasses - If the only operator you are interested in quantizing is linear, you can use `linear activation weight observer `__, we also have a corresponding `insert_observer_ `__ API that handles modifying the weight of linear. - -2. Module Swap - Alternatively, you could also define and `ObservedLinear `__ module (or other module types) and swap the non observed with the observed module - -Calibration -^^^^^^^^^^^ -Calibration step is typically straightforward, typically we just need to run the model through the calibration dataset. For more complicated calibration (e.g. where we record all inputs and do optimizations based on all inputs), we'll cover some of them in next section. - -Quantize -^^^^^^^^ -We can reuse the ``quantize_`` API but provide a different ``apply_tensor_subclass`` function that converts the observed linear module to a linear module with quantized weight and statically quantized input activation, this can be done in the same manner as the dynamic quantization (with ``to_linear_activation_quantized``), see `example `__. - -Alternatively, user can do `module swap `__ as well. - -Other Quantization Flows -######################## - -For other quantization flow/algorithms that does not fit into any of the above, we also intend to provide examples for common patterns. For example, `GPTQ like quantization flow `__ that is adopted by `Autoround `__, it uses `MultiTensor `__ and module hooks to optimize the module. 
- -If you are working on a new quantization algorithm/flow and not sure how to implement it in a PyTorch native way, please feel free to open an issue to describe how your algorithm works and we can help advise on the implementation details. - -Training -######## -The above flow are mainly focused on inference, but low bit dtype Tensors can be used in training as well. - -Quantization Aware Training -*************************** -TODO - - -Low Bit Optimizers -****************** -Today we have some prototype low bit optimizers: `main/torchao/prototype/low_bit_optim `__ that implements a specific type of 4 bit, 8 bit and float8, and is also composable with FSDP (with look up table quantization). - -Quantized Training -****************** -Similar to low bit optimizers, we have quantized training prototype in `main/torchao/prototype/quantized_training `__, and we could extend AffineQuantizedTensor to support training as well, initial enablement is in progress, but there will be a lot of follow up work needed including making it work for different kernels etc. - -You can also checkout the tutorial for `Quantized Training `__ that talks about how to make a dtype tensor subclass trainable. - -Case Study: How int4 weight only quantization works in torchao? -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -To connect everything together, here is a more detailed walk through for how int4 weight only quantization is implemented in torchao. - -Quantization Flow: quantize_(model, Int4WeightOnlyConfig()) - * What happens: linear.weight = torch.nn.Parameter(to_affine_quantized_intx(linear.weight), requires_grad=False) - * quantization primitive ops: choose_qparams and quantize_affine are called to quantize the Tensor - * quantized Tensor will be `AffineQuantizedTensor`, a quantized tensor with derived dtype (e.g. int4 with scale and zero_point) - * packing op `_convert_weight_to_int4pack` to pack the quantized weight for efficient execution - -During Model Execution: model(input) - * `torch.ops.aten._weight_int4pack_mm` is called on input and the packed weight - -During Quantization -################### -First we start with the API call: ``quantize_(model, Int4WeightOnlyConfig())`` what this does is it converts the weights of nn.Linear modules in the model to int4 quantized tensor (``AffineQuantizedTensor`` that is int4 dtype, asymmetric, per group quantized), using the layout for tinygemm kernel: ``tensor_core_tiled`` layout. 
- -* `quantize_ `__: the model level API that quantizes the weight of linear by applying the conversion function from user (second argument) -* `Int4WeightOnlyConfig `__: the function that returns a function that converts weight of linear to int4 weight only quantized weight - * Calls quantization primitives ops like choose_qparams_affine and quantize_affine to quantize the Tensor -* `TensorCoreTiledLayout `__: the tensor core tiled layout type, storing parameters for the packing format -* `TensorCoreTiledAQTTensorImpl `__: the tensor core tiled TensorImpl, stores the packed weight for efficient int4 weight only kernel (tinygemm kernel) - -During Model Execution -###################### - -When we run the quantized model ``model(inputs)``, we'll run through the functional linear operator in nn.Linear:: - - return F.linear(input, weight, bias) - -where input is a ``bfloat16`` Tensor, weight is an int4 ``AffineQuantizedTensor``, it calls into a ``__torch_function__`` of the ``AffineQuantizedTensor`` subclass, which will end up in an implementation for ``F.linear`` when one of the input is ``AffineQuantizedTensor``, so it calls:: - return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) - -The ``_quantized_linear_op`` goes through the ``_AQT_QLINEAR_DISPATCH_TABLE`` and checks each dispatch conditions, if the dispatch condition passes, it will call the implementation with ``input``/``weight``/``bias``. Please check out `this doc `__ for the explanation of ``dispatch_condition`` and ``impl``. - -int4 weight only `dispatch_condition `__ checks if the input is ``bfloat16`` Tensor and weight is a uint4 ``AffineQuantizedTensor`` -wint4 weight only quantization `kernel implementation `__ takes an bfloat16 input Tensor and an int4 AffineQuantizedTensor, and call ``torch.ops.aten._weight_int4pack_mm`` with the input Tensor and the packed weight that's stored in ``weight_tensor.tensor_impl``. - -During Save/Load -################ - -Since ``AffineQuantizedTensor`` weight is still a ``torch.Tensor``, save/load works the same way as the original high precision floating point model. See the `serialization doc `__ for more details. - - diff --git a/docs/source/quantization_overview.rst b/docs/source/quantization_overview.rst new file mode 100644 index 0000000000..bdb675eff8 --- /dev/null +++ b/docs/source/quantization_overview.rst @@ -0,0 +1,230 @@ +Quantization Overview +--------------------- + +First we want to lay out the torchao stack:: + + Quantization Algorithms/Flows: weight only/dynamic/static quantization, hqq, awq, gptq etc. 
+    ---------------------------------------------------------------------------------------------
+    Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Float8Tensor
+    ---------------------------------------------------------------------------------------------
+    Quantization Primitive Ops/Efficient Kernels: matmul, quantize, dequantize
+    ---------------------------------------------------------------------------------------------
+    Basic dtypes: uint1-uint7, int1-int8, float3-float8
+
+
+Any quantization algorithm will use some components from the above stack; for example, per row float8 dynamic activation and float8 weight quantization (with the default kernel preference) uses:
+
+* dynamic quantization flow
+* `Float8Tensor `__
+* `float8 activation + float8 weight fbgemm kernel `__ and `triton quant primitive ops from fbgemm library `__
+* ``torch.float8_e4m3fn`` dtype
+
+Basic DTypes
+~~~~~~~~~~~~
+`dtype `__ is a bit of an overloaded term; by basic dtype, we mean the dtypes that make sense without any extra metadata (e.g. they make sense when people call ``torch.empty(.., dtype)``). For more details please check out `this post `__.
+
+No matter what quantization we are doing, in the end we will be using some low precision dtypes to represent the quantized data or quantization parameters. The low precision dtypes relevant for torchao are:
+
+* ``torch.uint1`` to ``torch.uint7``, available in pytorch 2.3 and later
+* ``torch.int1`` to ``torch.int7``, available in pytorch 2.6 and later
+* ``torch.float4_e2m1fn_x2``, ``torch.float8_e4m3fn``, ``torch.float8_e4m3fnuz``, ``torch.float8_e5m2``, ``torch.float8_e5m2fnuz``, ``torch.float8_e8m0fnu``
+
+In terms of actual implementation, ``uint1`` to ``uint7`` and ``int1`` to ``int7`` are just placeholders that do not have real implementations (i.e. ops do not work for PyTorch Tensors with these dtypes). An example PR that added these dtypes can be found `here `__. Floating point dtypes are what we call shell dtypes, which have limited op support.
+
+For more details please check out the `official PyTorch dtype doc `__.
+
+.. note::
+    Derived dtypes like mxfp8, mxfp4, nvfp4 are implemented with these basic dtypes, e.g. mxfp4 uses ``torch.float8_e8m0fnu`` for scale and ``torch.float4_e2m1fn_x2`` for 4 bit data.
+
+Quantization Primitive Ops
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+Quantization primitive ops are the operators used to convert between low precision quantized tensors and high precision tensors. We mainly have the following quantization primitive operators:
+
+* choose_qparams ops: choose quantization parameters (e.g. scale and zero_point for affine quantization) based on the original Tensor, typically used in dynamic quantization
+* quantize op: quantizes the original high precision tensor to a low precision tensor with the dtypes mentioned in the previous section, based on the quantization parameters
+* dequantize op: dequantizes the low precision tensor into a high precision tensor based on the quantization parameters
+
+There could be variations of the above to accommodate specific use cases, for example, for static quantization we may have ``choose_qparams_affine_with_min_max``, which chooses quantization parameters based on min/max values derived from the observation process.
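+
+To make the three ops concrete, here is a minimal per-tensor symmetric int8 sketch in plain PyTorch (for illustration only; torchao's actual primitives such as ``choose_qparams_affine`` support many more dtypes, granularities and options)::
+
+    import torch
+
+    x = torch.randn(4, 16, dtype=torch.bfloat16)
+
+    # choose_qparams: pick a scale so that the observed range maps onto int8 [-127, 127]
+    scale = x.abs().amax().float() / 127.0
+
+    # quantize: high precision -> low precision
+    xq = torch.clamp(torch.round(x.float() / scale), -127, 127).to(torch.int8)
+
+    # dequantize: low precision -> high precision (approximately recovers x)
+    x_dq = (xq.float() * scale).to(torch.bfloat16)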
+
+There could be multiple versions of an op that differ by the kernel library used; for example, for quantizing a bfloat16 Tensor to a raw float8 Tensor and scale, there are `_choose_scale_float8 `__ and `_quantize_affine_float8 `__ for the torchao implementation, and `torch.ops.triton.quantize_fp8_row `__ from the fbgemm library.
+
+Efficient kernels
+~~~~~~~~~~~~~~~~~
+We'll also have efficient kernels that work with the low precision tensors, for example:
+
+* `torch.ops.fbgemm.f8f8bf16_rowwise `__ (rowwise float8 activation and float8 weight matrix multiplication kernel in the fbgemm library)
+* `torch._scaled_mm `__ (float8 activation and float8 weight matrix multiplication kernel in PyTorch for both rowwise and tensorwise)
+* `int_matmul `__ that takes two int8 tensors and outputs an int32 tensor
+* `int_scaled_matmul `__ that does matmul and also applies a scale to the result.
+
+.. note::
+    We can also rely on torch.compile to generate kernels (through triton), for example the current int8 weight only quantization `kernel `__ just relies on torch.compile to get a speedup. In this case there is no custom handwritten "efficient kernel" corresponding to the type of quantization.
+
+Quantized Tensors (derived dtypes and packing format)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+On top of the basic dtypes, quantization primitive operators and efficient kernels, we can glue everything together and build out a Quantized (low precision) Tensor by subclassing torch.Tensor. It can be constructed from a high precision Tensor and some parameters that configure the specific quantization the user wants. We can also call these derived dtypes, since they can be represented with Tensors of basic dtypes and some extra metadata like scale.
+
+Another dimension for a quantized Tensor is the packing format, meaning how the quantized raw data is laid out in memory. For example, for int4, we can pack two elements together side by side in a uint8 value, or people can do some preshuffling/swizzling to make the format more efficient for memory operations (loading from memory to register) and computation.
+
+So in general we structure Tensor subclasses by derived dtype and packing format:
+
+.. list-table:: Tensor Subclasses in TorchAO
+   :widths: 20 10 30 40
+   :header-rows: 1
+
+   * - Tensor
+     - Derived Dtype
+     - Packing Format
+     - Support
+   * - Float8Tensor
+     - scaled float8
+     - plain (no packing needed)
+     - float8 act + float8 weight dynamic quantization and float8 weight only quantization
+   * - Int4Tensor
+     - scaled int4
+     - plain (pack 2 adjacent int4 to a single int8 value)
+     - int4 weight only quantization
+   * - Int4PreshuffledTensor
+     - scaled int4
+     - preshuffled (special format to optimize for loading)
+     - float8 act + int4 weight dynamic quantization and int4 weight only quantization
+
+.. note::
+    We don't have granularity specific tensor subclasses, i.e. no Float8RowwiseTensor or Float8BlockwiseTensor; all granularities are implemented in the same Tensor. We typically use a general ``block_size`` attribute to distinguish between different granularities, and each Tensor is allowed to support only a subset of all possible granularity options.
+
+.. note::
+    We also don't use dynamic activation in the name, since we are talking about the weight tensor object, and including information about the activation in the tensor subclass name would be confusing. However,
+    we do implement both weight only and dynamic activation quantization in the same linear function implementation, without relying on additional abstractions; this keeps the relevant quantization operations (quantization of activation and weight) close
+    to each other in the same tensor subclass.
+
+In terms of how we quantize a Tensor, most Tensors use affine quantization, meaning the low precision Tensor is quantized from the high precision Tensor by an affine mapping, that is: ``low_precision_val = high_precision_val / scale + zero_point``, where ``scale`` and ``zero_point`` are the quantization parameters that can be calculated by quantization primitive ops or through some optimization procedure. Another common type of quantization, especially for lower bitwidths (e.g. lower than 4 bit), is codebook / look up table based quantization, where the raw quantized data is the index used to look up a ``codebook`` that stores the values or vectors each index corresponds to. A common way to get the codebook and the raw quantized data for codebook quantization is kmeans clustering.
+
+Quantization Algorithms/Flows
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+At the top of the stack are the final quantization algorithms and quantization flows. Traditionally we have weight only quantization, dynamic quantization and static quantization, but now we are also seeing more types of quantization coming up.
+
+For demonstration purposes, let's say after the previous step we have ``Float8Tensor`` defined. ``Float8Tensor.from_hp`` takes a high precision floating point Tensor and a target dtype (e.g. ``torch.float8_e4m3fn``) and converts it to a ``Float8Tensor``.
+
+Note: the following is for explaining the concepts; a more detailed introduction to the utils and examples we provide can be found in the `Contributor Guide `__.
+
+Weight Only Quantization
+########################
+This is the simplest form of quantization and it's easy to apply weight only quantization to the model, especially since we have Quantized Tensors. All we need to do is::
+
+    linear_module.weight = torch.nn.Parameter(Float8Tensor.from_hp(linear_module.weight, ...), requires_grad=False)
+
+Apply the above to all linear modules in the model and we'll get a weight only quantized model.
+
+Dynamic Activation and Weight Quantization
+##########################################
+
+This was previously called "dynamic quantization"; it means we quantize the activation dynamically at runtime and also quantize the weights. Compared to weight only quantization, the main question is how we apply the quantization to the activation. In torchao we pass around the quantization keyword args for the activation, and the keyword args will be applied to the activation when needed (e.g. in linear)::
+
+    activation_dtype = torch.float8_e4m3fn
+    activation_granularity = PerRow()
+    # define kwargs for float8 activation quantization
+    act_quant_kwargs = QuantizeTensorToFloat8Kwargs(
+        activation_dtype,
+        activation_granularity,
+    )
+    weight_dtype = torch.float8_e4m3fn
+    weight_granularity = PerRow()
+    quantized_weight = Float8Tensor.from_hp(linear_module.weight, float8_dtype=weight_dtype, granularity=weight_granularity, act_quant_kwargs=act_quant_kwargs)
+    linear_module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+
+Static Activation Quantization and Weight Quantization
+######################################################
+We'll skip the instructions for now since we haven't seen many use cases for static quantization with the tensor subclass based flow; we recommend looking into the `PT2 export quantization flow `__ for static quantization.
+
+Other Quantization Flows
+########################
+
+For other quantization flows/algorithms that do not fit into any of the above, we also intend to provide examples for common patterns. For example, the `GPTQ like quantization flow `__ that is adopted by `Autoround `__ uses `MultiTensor `__ and module hooks to optimize the module.
+
+If you are working on a new quantization algorithm/flow and are not sure how to implement it in a PyTorch native way, please feel free to open an issue to describe how your algorithm works and we can help advise on the implementation details.
+
+Training
+########
+The above flows are mainly focused on inference, but low bit dtype Tensors can be used in training as well.
+
+User facing docs for float8 training can be found `here `__ and docs for finetuning can be found `here `__.
+
+Quantization Aware Training
+***************************
+TorchAO supports `quantization aware training `__ through the ``quantize_`` API as well.
+
+
+Low Bit Optimizers
+******************
+We support `low bit optimizers `__ that implement specific types of 4 bit, 8 bit and float8 optimizers, and are also composable with FSDP (with look up table quantization).
+
+Quantized Training
+******************
+We have a quantized training prototype in `main/torchao/prototype/quantized_training `__, and we could extend existing tensor subclasses to support training as well; initial enablement is in progress, but there will be a lot of follow up work needed, including making it work for different kernels etc.
+
+You can also check out the tutorial for `Quantized Training `__ that talks about how to make a dtype tensor subclass trainable.
+
+Case Study: How float8 dynamic activation and float8 weight quantization works in torchao?
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+To connect everything together, here is a more detailed walk through of float8 dynamic activation and float8 weight quantization in torchao (default kernel preference, on H100, when the fbgemm_gpu_genai library is installed):
+
+Quantization Flow: ``quantize_(model, Float8DynamicActivationFloat8WeightConfig())``
+    * What happens: ``linear.weight = torch.nn.Parameter(Float8Tensor.from_hp(linear.weight), requires_grad=False)``
+    * quantization primitive ops: ``torch.ops.triton.quantize_fp8_row``
+    * quantized Tensor will be ``Float8Tensor``, a quantized tensor with derived dtype of scaled float8
+
+During Model Execution: model(input)
+    * ``torch.ops.fbgemm.f8f8bf16_rowwise`` is called on the input, the raw float8 weight and the scales
+
+During Quantization
+###################
+First we start with the API call: ``quantize_(model, Float8DynamicActivationFloat8WeightConfig())``. What this does is convert the weights of nn.Linear modules in the model to ``Float8Tensor`` with the plain packing format; no packing is required, since ``torch.float8_e4m3fn`` can represent the quantized float8 raw data directly without additional operations.
+
+* `quantize_ `__: the model level API that quantizes the weight of linear by applying the config from the user (second argument)
+* `Float8DynamicActivationFloat8WeightConfig `__: the config for float8 dynamic activation and float8 weight quantization
+    * Calls the quantization primitive op ``torch.ops.triton.quantize_fp8_row`` to quantize a bfloat16 Tensor to a raw float8 Tensor and get a scale
+
+
+During Model Execution
+######################
+
+When we run the quantized model ``model(inputs)``, we'll run through the functional linear operator in nn.Linear::
+
+    return F.linear(input, weight, bias)
+
+where input is a ``bfloat16`` Tensor and weight is a ``Float8Tensor``. This calls into the ``__torch_function__`` of the ``Float8Tensor`` subclass, which ends up in an implementation for ``F.linear`` when one of the `inputs `__ is a ``Float8Tensor``::
+
+    @implements([torch.nn.functional.linear, aten.linear.default])
+    def _(func, types, args, kwargs):
+        input_tensor, weight_tensor, bias = (
+            args[0],
+            args[1],
+            args[2] if len(args) > 2 else None,
+        )
+        # quantizing activation, if `act_quant_kwargs` is specified
+        # (`act_quant_kwargs` is stored on the weight tensor when dynamic activation quantization is configured)
+        act_quant_kwargs = weight_tensor.act_quant_kwargs
+        if act_quant_kwargs is not None:
+            input_tensor = _choose_quant_func_and_quantize_tensor(
+                input_tensor, act_quant_kwargs
+            )
+
+        # omitting kernel_preference related code
+        # granularity checks, let's say we are doing rowwise quant
+        # both input_tensor and weight_tensor will now be Float8Tensor
+        # output shape of the linear: (*batch_dims, out_features)
+        out_shape = (*input_tensor.shape[:-1], weight_tensor.shape[0])
+        xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
+        wq = weight_tensor.qdata.contiguous()
+        x_scale = input_tensor.scale
+        w_scale = weight_tensor.scale
+        res = torch.ops.fbgemm.f8f8bf16_rowwise(
+            xq,
+            wq,
+            x_scale,
+            w_scale,
+        ).reshape(out_shape)
+        return res
+
+The function first quantizes the input to a ``Float8Tensor``, then gets the raw float8 Tensor and scale from both the input and weight Tensors (``t.qdata``, ``t.scale``), and calls the fbgemm kernel to do the matrix multiplication for float8 dynamic quantization: ``torch.ops.fbgemm.f8f8bf16_rowwise``.
+
+During Save/Load
+################
+
+Since the ``Float8Tensor`` weight is still a ``torch.Tensor``, save/load works the same way as for the original high precision floating point model. See the `serialization doc `__ for more details.
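+
+For reference, here is a minimal end-to-end sketch of the flow described in this case study (a toy model and a hypothetical checkpoint path; rowwise granularity is passed explicitly to match the rowwise kernels in the walkthrough, and the hardware/library assumptions above still apply)::
+
+    import torch
+
+    from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow, quantize_
+
+    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
+    quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
+
+    # F.linear dispatches through Float8Tensor.__torch_function__ as described above
+    out = model(torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16))
+
+    # save/load works the same way as for the high precision model
+    torch.save(model.state_dict(), "float8_checkpoint.pt")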
diff --git a/docs/source/quick_start.rst b/docs/source/quick_start.rst index 02b59c2430..c2e7a542df 100644 --- a/docs/source/quick_start.rst +++ b/docs/source/quick_start.rst @@ -210,7 +210,7 @@ In this quick start guide, we learned how to quantize a simple model with torchao. To learn more about the different workflows supported in torchao, see our main `README `__. For a more detailed overview of quantization in torchao, visit -`this page `__. +`this page `__. Finally, if you would like to contribute to torchao, don't forget to check out our `contributor guide `__ and our list of diff --git a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py index 443ddea00e..b07e509a79 100644 --- a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py +++ b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py @@ -35,7 +35,7 @@ def _choose_quant_func_and_quantize_tensor( ) -> torch.Tensor: """Given a tensor and a kwargs container, chooses a derived dtype (float8, int8, etc) to quantize tensor to, based on the type of quant_kwargs quantizes tensor to the derived dtype chosen in (1) - This is needed to support flexible quantization of activation and weights to various derived dtypes. + This is needed to support flexible quantization of activation to various derived dtypes. """ from torchao.quantization.quantize_.workflows import ( Float8Tensor,