Commit 4bde9f1

Update quantization overview and contributor guide doc
Summary: We have recently updated our design for structuring tensor subclasses in torchao to remove unnecessary abstractions, reduce indirection, and arrive at a structure that aligns better with people's intuitive understanding of different quantization use cases. Examples using the new design: pytorch#2463, pytorch#2687

Test Plan: check generated doc
1 parent 6cfa477 commit 4bde9f1

File tree

7 files changed: +310 −270 lines changed

docs/source/api_ref_utils.rst

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+.. _api_utils:
+
+
+=============
+torchao.utils
+=============
+
+.. currentmodule:: torchao.utils
+
+Tensor Subclass Utils
+---------------------
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+
+    TorchAOBaseTensor
+
+=====================================
+torchao.quantization.quantize_.common
+=====================================
+
+.. currentmodule:: torchao.quantization.quantize_.common
+
+quantize_ API Common Utils
+--------------------------
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+
+    KernelPreference
+    PackingFormat
+    QuantizeTensorKwargs
+    _choose_quant_func_and_quantize_tensor

docs/source/contributor_guide.rst

Lines changed: 52 additions & 24 deletions
@@ -4,16 +4,33 @@ Contributor Guide
 General Guide on Extending torchao
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-For a new use case, for example, a training dtype (like fp4 training), it's fine to start with adding a new tensor subclass in prototype folder `torchao/prototype <https://github.com/pytorch/ao/tree/main/torchao/prototype>`__, but you could also take a look at ``AffineQuantizedTensor`` if what you want to do is mostly supported there, e.g. adding int3 kernel for the exact same affine quantization. Please feel free to open an issue and if you have questions on what to do for a specific new use case. For more details, please refer to our `quantization overview page <quantization.html>`__.
+Please start by reading our `quantization overview page <quantization_overview.html>`__ first.
 
 To contribute to existing code base:
 
-* Adding features to AffineQuantizedTensor, e.g. making it trainable, add tensor parallelism support etc.: `torchao/dtypes/affine_quantized_tensor.py <https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor.py>`__
+* Adding features to existing tensor subclasses like ``Float8Tensor``, e.g. adding new operator support, making them trainable, adding tensor parallelism support etc.: `tensor subclasses <https://github.com/pytorch/ao/tree/main/torchao/quantization/quantize_/workflows>`__, `tests <https://github.com/pytorch/ao/tree/main/test/quantization/quantize_/workflows>`__
 * Adding new quantization APIs: `torchao/quantization/quant_api.py <https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py>`__
 * Adding new quantization primitive ops, e.g. slight variations of existing quantization primitive ops: `torchao/quantization/quant_primitives.py <https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py>`__
 * Adding new autotuned triton kernels: `torchao/kernel <https://github.com/pytorch/ao/tree/main/torchao/kernel>`__
 * Adding new custom cpu/cuda/mps kernels: `torchao/csrc <https://github.com/pytorch/ao/tree/main/torchao/csrc>`__
-* Integrating custom kernel with AffineQuantizedTensor (maybe a new layout as well): Add sparse marlin AQT layout `#621 <https://github.com/pytorch/ao/pull/621>`__ as an example. We are still not decided if we want to split ``AffineQuantizedTensor`` to more tensor subclasses or not.
+
+Adding New Tensor Subclasses
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+torchao tensor subclasses are structured by ``derived dtype`` and ``packing format``; please check out the `quantization overview page <quantization_overview.html>`__ to understand these concepts. If your use case needs a tensor subclass that does not already exist, i.e. a new derived dtype or a new packing format, you can define a new Tensor.
+
+To understand how to use tensor subclasses in the context of quantization, please also check `Writing Your Own Quantized Tensor <https://docs.pytorch.org/ao/main/subclass_basic.html>`__.
+
+We have a utility base class, ``torchao.utils.TorchAOBaseTensor``, that defines common utility functions and methods for you, given the names of the Tensor and non-Tensor attributes of the tensor subclass, for example::
+
+    class MyTensor(TorchAOBaseTensor):
+        tensor_data_names = ["qdata", "scale"]
+        tensor_attribute_names = ["device", "dtype"]
+
+
+With the above, multiple methods and functions become available for this Tensor; for more details please check the docs for `TorchAOBaseTensor <https://docs.pytorch.org/ao/main/api_ref_utils.html>`__ (TODO: update to a more specific link).
+
+.. note::
+   Many of the existing use cases in torchao still use ``AffineQuantizedTensor``, but we plan to move away from it to reduce abstractions and make it easier for people to contribute to torchao.
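To make the above concrete, here is a hedged sketch of what a complete minimal subclass might look like; the constructor boilerplate is standard wrapper-subclass code, and the comments about auto-derived helpers are assumptions to verify against the ``TorchAOBaseTensor`` docs linked above::

    import torch
    from torchao.utils import TorchAOBaseTensor

    class MyTensor(TorchAOBaseTensor):
        tensor_data_names = ["qdata", "scale"]
        tensor_attribute_names = ["device", "dtype"]

        def __new__(cls, qdata, scale, device, dtype):
            # the outer wrapper tensor only carries metadata; the actual data
            # lives in the inner tensors (qdata, scale)
            return torch.Tensor._make_wrapper_subclass(
                cls, qdata.shape, device=device, dtype=dtype, requires_grad=False
            )

        def __init__(self, qdata, scale, device, dtype):
            self.qdata = qdata
            self.scale = scale

    # because tensor_data_names/tensor_attribute_names are specified, helpers such
    # as __tensor_flatten__/__tensor_unflatten__ (needed for torch.compile) and
    # _apply_fn_to_data can be derived by the base class (assumed; see the docs)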
 
 Adding Efficient Kernels
 ~~~~~~~~~~~~~~~~~~~~~~~~
@@ -31,44 +48,55 @@ Custom hand written kernels
 ###########################
 Custom kernels (implementations) for cpu/cuda/mps can be implemented through `torchao/csrc <https://github.com/pytorch/ao/tree/main/torchao/csrc>`__ e.g. int4 cuda, and accessible through torch.ops.my_custom_op
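As an illustration of the ``torch.ops`` calling convention, here is a small sketch using the Python-side ``torch.library.custom_op`` API with a hypothetical ``mylib`` namespace; real torchao kernels are C++/CUDA implementations under ``torchao/csrc`` registered so that they show up under ``torch.ops`` in the same way::

    import torch

    @torch.library.custom_op("mylib::my_custom_op", mutates_args=())
    def my_custom_op(x: torch.Tensor) -> torch.Tensor:
        # toy "kernel"; a real one would call into optimized cpu/cuda/mps code
        return x * 2

    # fake (meta) implementation so the op composes with torch.compile
    @my_custom_op.register_fake
    def _(x):
        return torch.empty_like(x)

    # the op is now accessible through the torch.ops namespace
    out = torch.ops.mylib.my_custom_op(torch.randn(4))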

-Dispatches
-~~~~~~~~~~
+Using hand written kernels in Tensor Subclasses
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+For calling optimized kernels, we use the ``implements`` decorator from the tensor subclass; for example, if we want to call into a new custom op ``torch.ops.ao.my_mm_for_mps``::
 
-For dispatching to optimized kernels for cpu/cuda/mps devices, we can have checks for the dispatch conditions in ``__torch_function__`` or ``__torch_dispatch__`` and dispatch to target operators, for example, condition for bfloat16 activation and uint4 weight kernel can be found `here <https://github.com/pytorch/ao/blob/242f181fe59e233b458740b06464ad42da8df6af/torchao/dtypes/affine_quantized_tensor.py#L1784-L1797>`__.
+    class Float8Tensor(TorchAOBaseTensor):
+        ...
 
-Specifically for ``AffineQuantizedTensor``, we also allow people to extend the quantized linear to use a new efficient kernel or implement by defining two functions:
-``dispatch_condition`` (defines the condition to dispatch to the kernel) and impl (actual implementation that takes activation, (quantized) weight, bias Tensor and runs the efficient kernel), both taking ``input_tensor``, ``weight_tensor``, ``bias`` as argument, and can be registered into dispatch of quantized linear in ``AffineQuantizedTensor`` with ``register_aqt_quantized_linear_dispatch``. `Here <https://github.com/pytorch/ao/blob/e283743b3cc4612bb641b88dca3670231724d396/test/dtypes/test_affine_quantized.py#L92-L113>`__ is an example showing how it works.
+    implements = Float8Tensor.implements
 
-Layout/TensorImpl
-~~~~~~~~~~~~~~~~~
+    @implements([torch.nn.functional.linear, aten.linear.default])
+    def _(func, types, args, kwargs):
+        input_tensor, weight_tensor, bias = args  # unpack activation, quantized weight, bias
+        # call into the custom op
+        res = torch.ops.ao.my_mm_for_mps(input_tensor.qdata, weight_tensor.qdata, input_tensor.scale, weight_tensor.scale)
+        return res
+
+KernelPreference
+################
+
+For some tensor subclasses, there could be multiple kernel choices for quantize, mm etc. The recommended way to handle this in torchao tensor subclasses is through ``KernelPreference``, which represents which group of kernels we want to use for quantize, mm, grouped mm etc. We use ``KernelPreference.AUTO`` as the default option, where developers choose whatever we think is the fastest under different conditions for the user, so users don't need to worry about the details, and we can have other more specific kernel options for debugging purposes.
+
+``Float8Tensor``, for example, has:
+
+* ``KernelPreference.AUTO``: chooses the most performant quantize and mm kernels based on hardware (e.g. SM89, or SM90+ such as H100), availability of libraries (whether fbgemm_gpu_genai is installed) and granularity (per row or per tensor)
+* ``KernelPreference.TORCH``: uses the torchao quantize ops (``_choose_scale_float8`` and ``_quantize_affine_float8``) and ``_scaled_mm``
+* ``KernelPreference.FBGEMM``: uses the fbgemm quantize and mm op (``torch.ops.fbgemm.f8f8bf16_rowwise``)
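To illustrate the idea, here is a toy, self-contained sketch of preference-based kernel selection; the enum values mirror the options above, but the selection logic and signatures are assumptions rather than torchao's actual code::

    from enum import Enum

    class KernelPreference(Enum):  # mirrors the torchao options described above
        AUTO = "auto"
        TORCH = "torch"
        FBGEMM = "fbgemm"

    def choose_mm_kernel(pref: KernelPreference, fbgemm_installed: bool, sm: int) -> str:
        # toy selection logic (assumed): AUTO picks the fbgemm rowwise fp8 mm when
        # fbgemm_gpu_genai is installed and the GPU is new enough, else torch
        if pref == KernelPreference.FBGEMM:
            return "torch.ops.fbgemm.f8f8bf16_rowwise"
        if pref == KernelPreference.AUTO and fbgemm_installed and sm >= 89:
            return "torch.ops.fbgemm.f8f8bf16_rowwise"
        return "torch._scaled_mm"

    print(choose_mm_kernel(KernelPreference.AUTO, fbgemm_installed=True, sm=90))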
 
-Sometimes the quantized weights has to be packed in order to yield optimal performance. And this can be abstracted with ``layout``. See `here <https://github.com/pytorch/ao/blob/17a0a96d24ebfc154a23342b84e788d9ed6776f4/tutorials/developer_api_guide/my_dtype_tensor_subclass.py#L215-L317>`__ for full example.
 
 Flow
 ~~~~
 
-After the tensor subclass is implemented, we can also wrap that into factory functions, e.g.::
-    # convert from floating point tensor to my dtype tensor subclass
-    to_my_dtype = MyDTypeTensor.from_float
-
-For model level API, people can reuse ``torchao.quantize_`` that allows people to apply a tensor subclass conversion to weight of linear, and allows `filtering function <https://github.com/pytorch/ao/blob/17a0a96d24ebfc154a23342b84e788d9ed6776f4/torchao/quantization/quant_api.py#L421>`__ to choose which module the tensor subclass conversion should be applied to.
+For the model level API, people can reuse ``torchao.quantize_``, which applies a tensor subclass conversion to the weight of linear modules, and accepts a `filtering function <https://docs.pytorch.org/ao/main/generated/torchao.quantization.quantize_.html#torchao.quantization.quantize_>`__ to choose which modules the conversion should be applied to; see the usage sketch below.
 
-See Quantization Algorithms/Flows section for examples of weight only/dynamic quant/static quant and other types of model level APIs based on the factory function.
+See the Quantization Algorithms/Flows section for examples of weight only/dynamic quant and other types of model level APIs.
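Here is the usage sketch referenced above; ``Float8DynamicActivationFloat8WeightConfig`` is torchao's float8 dynamic quantization config, and the filter function is illustrative::

    import torch
    from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig

    model = torch.nn.Sequential(torch.nn.Linear(128, 256), torch.nn.Linear(256, 128))

    # quantize only the first Linear; fqn is the module's fully qualified name
    # (note: actually running float8 matmuls requires supported hardware)
    quantize_(
        model,
        Float8DynamicActivationFloat8WeightConfig(),
        filter_fn=lambda m, fqn: isinstance(m, torch.nn.Linear) and fqn == "0",
    )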
 
 Using torch.compile for Performance
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-Note: for pytorch 2.4 and below, we need to use the following::
-    from torchao.utils import unwrap_tensor_subclass
-    m_unwrapped = unwrap_tensor_subclass(m)
-
 In order to be compatible with ``torch.compile``, to aim for performance optimization, we should run through ``torch.compile`` with ``fullgraph=True`` first, and remove any unnecessary graph breaks. You can add ``TORCH_LOGS="output_code"`` when you run the script in order to see the inductor generated code. e.g. ``TORCH_LOGS="output_code" python example.py``::
+
     model = torch.compile(model, mode="max-autotune", fullgraph=True)
 
 Serialization
 ~~~~~~~~~~~~~
 
-Please checkout the `serialization doc <https://pytorch.org/ao/stable/serialization.html>`__ for more details.
+To enable support for serialization (``torch.save`` and ``torch.load`` with tensor subclasses as weights), we need to add the tensor subclass and the relevant objects to safe globals (available after torch 2.5), e.g.::
+
+    torch.serialization.add_safe_globals([Float8Tensor, QuantizeTensorToFloat8Kwargs])
+
+Please check out the `serialization doc <serialization.html>`__ for more details.
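For example, a save/load round trip might look like the following sketch; the import location of ``Float8Tensor`` is an assumption, so adjust it to wherever the class lives in your torchao version::

    import torch
    from torchao.quantization import Float8Tensor  # assumed import path

    # `model` is a module already quantized with quantize_ as above;
    # allowlist the subclass so that weights_only loading can reconstruct it
    torch.serialization.add_safe_globals([Float8Tensor])

    torch.save(model.state_dict(), "quantized_model.pt")
    state_dict = torch.load("quantized_model.pt", weights_only=True)
    model.load_state_dict(state_dict, assign=True)  # assign keeps the subclass weights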
 
 .. note::
    We are integrated with huggingface transformer and supports serialization/deserialization through the huggingface save_pretrained/push_to_hub/from_pretrained APIs: https://huggingface.co/docs/transformers/main/en/quantization/torchao
@@ -85,8 +113,6 @@ The above just talks about basic feature support, we also provide examples on ho
 * `Quantized Training <https://github.com/pytorch/ao/blob/main/tutorials/developer_api_guide/my_trainable_tensor_subclass.py>`__
 * `Tensor Parallel Support for Quantized Tensor <https://github.com/pytorch/ao/blob/main/tutorials/developer_api_guide/tensor_parallel.py>`__
 * `Compatibility with executorch / torchchat <https://github.com/pytorch/ao/blob/main/tutorials/developer_api_guide/export_to_executorch.py>`__
-* [TODO] FSDP
-* [TODO] QAT
 
 
 Tensor Subclass Functionality/Composability Testing
@@ -134,3 +160,5 @@ Note: llama model (llama2/llama3) is our representative model for memory bound m
 Please checkout the ``--help`` option for each of the script to understand the supported options, e.g. you can use ``--profile=profile_path`` to get the chrome trace of the run to understand detailed `chrome trace <https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html#using-tracing-functionality>`__.
 
 Please let us know if there are any new important models that makes sense to be added to torchao model benchmark/eval folder.
+
+Please also check out `Benchmarking User Guide <https://docs.pytorch.org/ao/main/benchmarking_user_guide.html>`__ and `Benchmarking API Guide <https://docs.pytorch.org/ao/main/benchmarking_api_guide.html>`__ to understand how to use our benchmarking framework.

docs/source/index.rst

Lines changed: 2 additions & 1 deletion
@@ -18,7 +18,7 @@ for an overall introduction to the library and recent highlight and updates.
    :maxdepth: 1
    :caption: Developer Notes
 
-   quantization
+   quantization_overview
    sparsity
    contributor_guide
    benchmarking_api_guide
@@ -34,6 +34,7 @@ for an overall introduction to the library and recent highlight and updates.
    api_ref_qat
    api_ref_sparsity
    api_ref_float8
+   api_ref_utils
 
 .. toctree::
    :glob:
