
Commit 1bd3613

Update on "Documentation Updates"
Summary: Updating README with better examples, updating class and API documentation, and removing the unnecessary int_mm_fused_mul option from dynamic quant

Test Plan: python test/test.py

Reviewers:
Subscribers:
Tasks:
Tags:

[ghstack-poisoned]
1 parent d375485 commit 1bd3613

1 file changed: README.md (24 additions, 16 deletions)
@@ -27,49 +27,57 @@ torchao 0.0.1 <install dir>

 Relevant APIs can be found in torchao.quantization.quant_api

+Note: Depending on the technique being applied to the model, you may see a perf degradation.
+This is because quantization adds additional overhead to the model that is hopefully made up for
+with faster matmuls. If your matmuls are small enough (or have odd shapes), the overhead can be larger than the gain
+from the quantized matmul.

-### A16W8 WeightOnly Quantization
+### A8W8 Dynamic Quantization

-The `apply_weight_only_int8_quant` function swaps all
-linear modules to weight-only quantized linear modules.
+Similar to the weight only api above, the `apply_dynamic_quant` function swaps all
+linear modules to dynamically quantized linear modules.

 Example

 ```
-import torch
-from torchao.quantization import quant_api

 # some user model and example input
-model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
-input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
+...

 # convert linear modules to quantized linear modules
-quant_api.apply_weight_only_int8_quant(model)
+quant_api.apply_dynamic_quant(model)

 # compile the model to improve performance
-torch.compile(model, mode='max-autotune')
-model(input)
+...
 ```

-### A8W8 Dynamic Quantization
+This technique works best when the torch._inductor.config.force_fuse_int_mm_with_mul option is enabled. This allows fusion of the int8*int8 -> int32 matmul and subsequent mul op, thereby avoiding materialization of the int32 intermediary tensor.

-Similar to the weight only api above, the `apply_dynamic_quant` function swaps all
-linear modules to dynamically quantized quantized linear modules.
+### A16W8 WeightOnly Quantization
+
+The `apply_weight_only_int8_quant` function swaps all
+linear modules to weight-only quantized linear modules.

 Example

 ```
+import torch
+from torchao.quantization import quant_api

 # some user model and example input
-...
+model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
+input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')

 # convert linear modules to quantized linear modules
-quant_api.apply_dynamic_quant(model)
+quant_api.apply_weight_only_int8_quant(model)

 # compile the model to improve performance
-...
+torch.compile(model, mode='max-autotune')
+model(input)
 ```

+This technique works best when the torch._inductor.config.use_mixed_mm option is enabled. This avoids dequantizing the weight tensor before the matmul, instead fusing the dequantization into the matmul, thereby avoiding materialization of a large floating point weight tensor.
+
 ## Other APIs

 ### A8W8 Dynamic Quantization by subclasses
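For reference, a minimal runnable sketch of the dynamic-quantization flow described in the new README text, with a toy model standing in for the elided `...` placeholders and the inductor flag from the README note enabled up front; the flag placement and the `torch.compile` reassignment are assumptions, not part of the committed example:

```
import torch
from torchao.quantization import quant_api

# Enable fusion of the int8*int8 -> int32 matmul with the subsequent mul,
# as recommended in the README text (assumed to be set before compiling).
torch._inductor.config.force_fuse_int_mm_with_mul = True

# Toy stand-ins for the user model and example input elided as "..." in the diff.
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')

# Convert linear modules to dynamically quantized linear modules.
quant_api.apply_dynamic_quant(model)

# Compile the model to improve performance, then run it.
model = torch.compile(model, mode='max-autotune')
model(input)
```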

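Similarly, a sketch of the weight-only path with the `use_mixed_mm` option from the README note enabled; setting the flag before compiling is an assumption, while the rest mirrors the committed example:

```
import torch
from torchao.quantization import quant_api

# Fuse the int8 weight dequantization into the matmul instead of
# materializing a full floating point weight tensor (per the README note).
torch._inductor.config.use_mixed_mm = True

# Same toy model and input as the committed weight-only example.
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')

# Convert linear modules to weight-only int8 quantized linear modules.
quant_api.apply_weight_only_int8_quant(model)

# Compile the model so inductor can apply the mixed-dtype matmul fusion.
model = torch.compile(model, mode='max-autotune')
model(input)
```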