Commit 27bf5bf

Merge pull request #9 from pytorch-labs/gh/HDCharles/1/base
Documentation Updates
2 parents bf3659d + a312372 commit 27bf5bf

File tree: 9 files changed (+173 -87 lines)

README.md

Lines changed: 69 additions & 6 deletions
@@ -29,29 +29,92 @@ torchao 0.0.1 <install dir>
 
 Relevant APIs can be found in torchao.quantization.quant_api
 
+Note: While these techniques are designed to improve model performance, in some cases the opposite can occur.
+This is because quantization adds overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization) or faster weight loading (weight-only quantization). If your matmuls are small enough, or your non-quantized performance isn't bottlenecked by weight load time, these techniques may reduce performance.
+
+### A8W8 Dynamic Quantization
+
+Similar to the weight-only API below, the `apply_dynamic_quant` function swaps all
+linear modules to dynamically quantized linear modules.
+
+Example
+
+```
+
+# some user model and example input
+...
+
+# convert linear modules to quantized linear modules
+quant_api.apply_dynamic_quant(model)
+
+# compile the model to improve performance
+...
+```
+
+This technique works best when the `torch._inductor.config.force_fuse_int_mm_with_mul` option is enabled. This allows fusion of the int8*int8 -> int32 matmul and the subsequent mul op, thereby avoiding materialization of the int32 intermediate tensor.
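For reference, a minimal end-to-end sketch of the recipe above, with the elided model and input filled in by toy placeholders (the one-layer model and the explicit `torch._inductor.config` import are illustrative assumptions, not part of this commit):

```
import torch
import torch._inductor.config
from torchao.quantization import quant_api

# allow inductor to fuse the int8 matmul with the subsequent mul ops
torch._inductor.config.force_fuse_int_mm_with_mul = True

# toy stand-ins for "some user model and example input"
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')

# convert linear modules to dynamically quantized linear modules
quant_api.apply_dynamic_quant(model)

# compile the model to improve performance (keep the compiled module)
model = torch.compile(model, mode='max-autotune')
model(input)
```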
+
+### A16W8 WeightOnly Quantization
+
+The `apply_weight_only_int8_quant` function swaps all
+linear modules to weight-only quantized linear modules.
+
 Example
 
 ```
 import torch
 from torchao.quantization import quant_api
 
-# some user model
+# some user model and example input
 model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
-# some example input
 input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
 
 # convert linear modules to quantized linear modules
-# insert quantization method/api of choice
 quant_api.apply_weight_only_int8_quant(model)
-# quant_api.apply_dynamic_quant(model)
-# quant_api.change_linear_weights_to_dqtensors(model)
 
 # compile the model to improve performance
 torch.compile(model, mode='max-autotune')
 model(input)
 ```
 
-### A16W8 WeightOnly Quantization
+This technique works best when the `torch._inductor.config.use_mixed_mm` option is enabled. This avoids dequantizing the weight tensor before the matmul, instead fusing the dequantization into the matmul, thereby avoiding materialization of a large floating point weight tensor.
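As a small sketch, the flag can be flipped before compiling the example above (assuming `torch._inductor.config` needs an explicit import, as it usually does):

```
import torch._inductor.config

# fuse weight dequantization into the matmul instead of
# materializing a full floating point weight tensor
torch._inductor.config.use_mixed_mm = True
```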
+
+## Other APIs
+
+### A8W8 Dynamic Quantization by subclasses
+
+You can use [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) to do dynamic quantization with the `change_linear_weights_to_dqtensors` function, using the exact same quantization as above. This avoids modifying the model graph and can be more composable with
+other techniques.
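A minimal usage sketch for this API, mirroring the weight-only example above (the toy model and input are placeholders):

```
import torch
from torchao.quantization import quant_api

model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')

# swap each linear weight for a DynamicallyQuantizedLinearWeight tensor subclass;
# the nn.Linear modules themselves are left untouched
quant_api.change_linear_weights_to_dqtensors(model)

model = torch.compile(model, mode='max-autotune')
model(input)
```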
+
+### A8W8 Dynamic Quantization with Smoothquant
+
+We've also implemented a version of [smoothquant](https://arxiv.org/abs/2211.10438) with the same GEMM format as above.
+Because it requires calibration, the API is slightly more complicated.
+
+Example
+
+```
+import torch
+from torchao.quantization.smoothquant import swap_linear_with_smooth_fq_linear, smooth_fq_linear_to_inference
+
+# some user model
+model = get_model()
+
+# convert linear modules to smoothquant
+# linear modules in calibration mode
+swap_linear_with_smooth_fq_linear(model)
+
+# calibration
+for i in range(calibration_amount):
+    input = get_input()
+    model(input)
+
+# set it to inference mode
+smooth_fq_linear_to_inference(model)
+
+# compile the model to improve performance
+torch.compile(model, mode='max-autotune')
+model(input)
+```
 
 ## License
 
test/test.py

Lines changed: 2 additions & 2 deletions
@@ -21,6 +21,7 @@
     apply_dynamic_quant,
     apply_weight_only_int8_quant,
     change_linear_weights_to_dqtensors,
+    _replace_with_custom_fn_if_matches_filter,
 )
 from torchao.quantization.quant_primitives import (
     dequantize_per_channel,
@@ -35,7 +36,6 @@
 
 from torchao.quantization.smoothquant import (
     get_scale,
-    replace_with_custom_fn_if_matches_filter,
     smooth_fq_linear_to_inference,
     SmoothFakeDynamicallyQuantizedLinear,
     swap_linear_with_smooth_fq_linear,
@@ -284,7 +284,7 @@ def test_selective_torch_compile(self):
         x = torch.randn(4, 4)
         y_ref = m(x)
 
-        replace_with_custom_fn_if_matches_filter(
+        _replace_with_custom_fn_if_matches_filter(
             m,
             lambda mod: torch.compile(mod),
             lambda mod, fqn: isinstance(mod, nn.Linear) and fqn != "1.0",

torchao/quantization/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -13,7 +13,6 @@
 
 __all__ = [
     "DynamicallyPerAxisQuantizedLinear",
-    "replace_with_custom_fn_if_matches_filter",
     "apply_weight_only_int8_quant",
     "apply_dynamic_quant",
     "change_linear_weights_to_dqtensors",

torchao/quantization/dynamic_quant.py

Lines changed: 12 additions & 28 deletions
@@ -16,56 +16,45 @@
 
 class DynamicallyPerAxisQuantizedLinear(torch.nn.Linear):
     """
-    This class is a replacement for `torch.nn.Linear`, implementing dynamic quantization on
-    the input across all axes except for the last axis.
+    This class is a replacement for `torch.nn.Linear`. It implements a
+    quantized matmul using int8 dynamic symmetric per-token activation,
+    and int8 symmetric per-channel weight quantization
     """
 
     def __init__(
         self,
         in_features: int,
         out_features: int,
         bias: bool = True,
-        use_fused_int_mm=False,
     ) -> None:
         super().__init__(in_features, out_features, bias)
-        self.use_fused_int_mm = use_fused_int_mm
-        # note: enabling use_fused_int_mm = True has best perf when additionally setting
-        # torch._inductor.config.force_fuse_int_mm_with_mul = True
 
     def forward(self, X: torch.Tensor) -> torch.Tensor:
         """
-        Performs the forward pass of the quantized linear layer.
-
-        This method applies dynamic quantization to the input tensor across all axes except
-        the last axis using the `quant_int8_dynamic_per_token_linear` function.
+        Performs the forward pass of the quantized linear layer which consists
+        of int8 dynamic symmetric per-token activation and int8 symmetric per-channel weight
+        quantization
 
         Args:
-            X (torch.Tensor): The input tensor to the quantized linear layer.
+            X (torch.Tensor): The input floating point tensor to the quantized linear layer.
 
         Returns:
-            torch.Tensor: The output tensor after the quantized matmul and rescale.
+            torch.Tensor: The output floating point tensor after the quantized matmul and rescale.
 
         """
-        # The following line mimics the behavior of SmoothFakeDynamicallyQuantizedLinear
-        if not self.use_fused_int_mm:
-            X = X / self.fake_rescale
-        # somehow the inductor fusion that occurs for most transformer models
-        # when this module has an additional div op is faster than when it doesn't
-        # have it although the memory usage is slightly higher. fake_rescale is scalar 1
-        # so it doesn't affect accuracy
+
         Y = quant_int8_dynamic_per_token_linear(
             X, self.W_int_repr_t, self.W_scales, self.bias, X.dtype
         )
         return Y
 
     @classmethod
     def from_float(
-        cls, mod: torch.nn.Linear, use_fused_int_mm=False
+        cls, mod: torch.nn.Linear
     ) -> "DynamicallyPerAxisQuantizedLinear":
         """
-        Converts a `mod` of class `torch.nn.Linear` to the dynamically quantized version of it.
-
-        Note: this class does not require calibration.
+        Converts a `mod` of class `torch.nn.Linear` to the
+        `DynamicallyPerAxisQuantizedLinear` class
 
         Args:
             mod (torch.nn.Linear): The original `torch.nn.Linear` module to convert.
@@ -81,7 +70,6 @@ def from_float(
             fake_in_features,
             fake_out_features,
             bias=mod.bias is not None,
-            use_fused_int_mm=use_fused_int_mm,
         )
         new_mod.in_features = mod.in_features
         new_mod.out_features = mod.out_features
@@ -91,10 +79,6 @@ def from_float(
         new_mod.register_buffer("W_int_repr_t", W_int_repr.contiguous().t())
         new_mod.W_scales = nn.Parameter(W_scales)
         new_mod.bias = mod.bias
-        if not use_fused_int_mm:
-            new_mod.fake_rescale = torch.tensor(
-                [1.0], dtype=mod.weight.dtype, device=mod.weight.device
-            )
         del new_mod.weight
 
         device_to_use = next(mod.parameters()).device
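For context on the simplified signature, a small usage sketch of the swapped module (assuming the class is exported from `torchao.quantization` as its `__all__` suggests; the layer size and input shape are placeholders):

```
import torch
from torchao.quantization import DynamicallyPerAxisQuantizedLinear

# swap a single float linear layer for its dynamically quantized counterpart
lin = torch.nn.Linear(32, 64).cuda().to(torch.bfloat16)
qlin = DynamicallyPerAxisQuantizedLinear.from_float(lin)

x = torch.randn(8, 32, dtype=torch.bfloat16, device='cuda')
# int8 per-token activation x int8 per-channel weight matmul, rescaled back to bf16
y = qlin(x)
```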

torchao/quantization/quant_api.py

Lines changed: 28 additions & 18 deletions
@@ -5,10 +5,14 @@
 # LICENSE file in the root directory of this source tree.
 
 """
-Quantization API stuff which is not specific to SmoothQuant
+Quantization APIs
 
-Note: this is throwaway code for fast results on Blueberry, this is not
-intended to be the actual long term quantization API for server GPUs.
+Generally these APIs can be applied directly to any model
+with Linear modules to obtain quantized linear ops. The intended
+usage involves applying torch.compile to the model afterwards
+both because primitives were designed based on the fusions that
+come along with it and because that is how we access the intended quantized
+and mixed GEMM kernels
 """
 
 import torch
@@ -23,14 +27,13 @@
 )
 
 __all__ = [
-    "replace_with_custom_fn_if_matches_filter",
     "apply_weight_only_int8_quant",
     "apply_dynamic_quant",
     "change_linear_weights_to_dqtensors",
 ]
 
 
-def replace_with_custom_fn_if_matches_filter(
+def _replace_with_custom_fn_if_matches_filter(
     model, replacement_fn, filter_fn, cur_fqn=""
 ) -> None:
     """
@@ -47,34 +50,41 @@ def replace_with_custom_fn_if_matches_filter(
             new_child = replacement_fn(child)
             setattr(model, name, new_child)
         else:
-            replace_with_custom_fn_if_matches_filter(
+            _replace_with_custom_fn_if_matches_filter(
                 child, replacement_fn, filter_fn, new_fqn
             )
-
-
 def apply_weight_only_int8_quant(model):
-    replace_with_custom_fn_if_matches_filter(
+    """
+    Applies weight-only symmetric per-channel int8 quantization to all linear layers
+    in the given model using module swaps.
+    """
+    _replace_with_custom_fn_if_matches_filter(
         model,
         WeightOnlyInt8QuantLinear.from_float,
         lambda mod, fqn: isinstance(mod, torch.nn.Linear),
     )
-
-
-def apply_dynamic_quant(model, use_fused_int_mm=0):
-    replace_with_custom_fn_if_matches_filter(
+def apply_dynamic_quant(model):
+    """
+    Applies dynamic symmetric per-token activation and per-channel weight
+    quantization to all linear layers in the given model using
+    module swaps.
+    """
+    _replace_with_custom_fn_if_matches_filter(
         model,
-        lambda mod: DynamicallyPerAxisQuantizedLinear.from_float(mod, use_fused_int_mm),
+        lambda mod: DynamicallyPerAxisQuantizedLinear.from_float(mod),
         lambda mod, fqn: isinstance(mod, torch.nn.Linear),
     )
-
-
 def change_linear_weights_to_dqtensors(model):
+    """
+    Converts all linear weight tensors to the `DynamicallyQuantizedLinearWeight`
+    Tensor subclass, effectively applying the same form of quantization
+    as apply_dynamic_quant while not modifying the linear modules.
+    """
     def insert_subclass(lin):
         lin.weight = torch.nn.Parameter(
             DynamicallyQuantizedLinearWeight.from_float(lin.weight), requires_grad=False
         )
         return lin
-
-    replace_with_custom_fn_if_matches_filter(
+    _replace_with_custom_fn_if_matches_filter(
        model, insert_subclass, lambda mod, fqn: isinstance(mod, torch.nn.Linear)
     )
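Although now private, the renamed helper is what the module-swap APIs above are built on. A sketch of the selective-replacement pattern, loosely following the usage in test/test.py (the small `nn.Sequential` and the compile-every-Linear filter are illustrative placeholders):

```
import torch
import torch.nn as nn
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

m = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 4), nn.ReLU()))

# compile each nn.Linear in place; the filter also receives the fully
# qualified name, so replacements can be limited to specific submodules
_replace_with_custom_fn_if_matches_filter(
    m,
    lambda mod: torch.compile(mod),
    lambda mod, fqn: isinstance(mod, nn.Linear),
)
y = m(torch.randn(4, 4))
```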

torchao/quantization/quant_primitives.py

Lines changed: 3 additions & 16 deletions
@@ -303,14 +303,13 @@ def quant_int8_dynamic_per_token_linear(
     w_vals_int8_t,
     w_scales,
     bias,
-    out_dtype=torch.float32,
-    use_fused_int_mm=0,
+    out_dtype,
 ):
     # like F.linear, but with int8 dynamic quantization of activation,
     # and a quantized weight
     x_vals_int8, x_scales = quantize_activation_per_token_absmax(x)
     mm_out = quant_int8_per_token_matmul(
-        x_vals_int8, x_scales, w_vals_int8_t, w_scales, out_dtype, use_fused_int_mm
+        x_vals_int8, x_scales, w_vals_int8_t, w_scales, out_dtype
     )
     if bias is not None:
         mm_out += bias
@@ -323,7 +322,6 @@ def quant_int8_per_token_matmul(
     w_vals_int8_t,
     w_scales,
     output_dtype=torch.float32,
-    use_fused_int_mm=0,
 ):
     # Quantized matmul of int8 operands that accumulates to int32 and returns
     # output_dtype. For now, this is written for approximate numerical
@@ -355,18 +353,6 @@ def quant_int8_per_token_matmul(
     #
 
     tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
-    # these branches use external triton fused_int_mm kernel's which fuse either 1 or 2 mul operations
-    if use_fused_int_mm == 2:
-        y = torch.ops.custom_int_mm.int_mm_dequant(
-            tmp, w_vals_int8_t, x_scales.view(-1, 1), w_scales, output_dtype
-        ).reshape(*x_vals_int8.shape[:-1], -1)
-        return y
-    elif use_fused_int_mm == 1:
-        y = torch.ops.custom_int_mm.int_mm_one_mul(
-            tmp, w_vals_int8_t, x_scales.view(-1, 1), output_dtype
-        ).reshape(*x_vals_int8.shape[:-1], -1)
-        y = y * w_scales
-        return y.to(output_dtype)
     y_dot_int32 = safe_int_mm(tmp, w_vals_int8_t)
 
     #
@@ -381,6 +367,7 @@ def quant_int8_per_token_matmul(
         torch.float,
         torch.bfloat16,
     ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}"
+
     y = (y_dot_int32 * x_scales.view(-1, 1) * w_scales).reshape(
         *x_vals_int8.shape[:-1], -1
     )
