Commit 1db9e08

jerryzh168 authored and facebook-github-bot committed
Clarify the output of quantization flow
Summary: att
Reviewed By: kimishpatel
Differential Revision: D48232772
fbshipit-source-id: 1c8c1fd2bd9c841151294b2c4f2c3d84cb7b9bc7
1 parent 49ef23c commit 1db9e08

1 file changed: +28 additions, -3 deletions


docs/website/docs/tutorials/quantization_flow.md

@@ -20,19 +20,44 @@ Note: Before quantizing models, each backend needs to implement their own `Quantizer`
 Please take a look at the [pytorch 2.0 export post training static quantization tutorial](https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_static.html) to learn about all the steps of quantization. The main APIs used to quantize the model are:
 * `prepare_pt2e`: used to insert observers into the model; it takes a backend-specific `Quantizer` as argument, which will annotate the nodes with the information needed to quantize the model properly for the backend
 * (not an API) calibration: run the model through some sample data
-* `convert_pt2e`: convert an observed model to a quantized model, we have special representation for selected ops (e.g. quantized linear), other ops are represented as (dq -> float32_op -> q), and q/dq are decomposed into more primitive operators.
+* `convert_pt2e`: convert an observed model to a quantized model.
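To make the calibration step concrete, here is a minimal, torch-free sketch of the bookkeeping an observer performs: record running min/max over sample batches, then derive affine quantization parameters. The `MinMaxObserver` class and its method names are illustrative, not the actual PyTorch observer API.

```python
# Torch-free sketch of calibration: an observer records running min/max over
# sample data, then derives affine quantization parameters (scale, zero_point)
# for int8. Illustrative only -- not the actual PyTorch observer API.

class MinMaxObserver:
    def __init__(self, quant_min=-128, quant_max=127):
        self.quant_min, self.quant_max = quant_min, quant_max
        self.min_val, self.max_val = float("inf"), float("-inf")

    def observe(self, values):
        # called once per calibration batch
        self.min_val = min(self.min_val, min(values))
        self.max_val = max(self.max_val, max(values))

    def calculate_qparams(self):
        # widen the range to include 0 so that 0.0 is exactly representable
        lo, hi = min(self.min_val, 0.0), max(self.max_val, 0.0)
        scale = (hi - lo) / (self.quant_max - self.quant_min)
        zero_point = round(self.quant_min - lo / scale)
        return scale, zero_point

obs = MinMaxObserver()
for batch in ([0.1, 0.5, -0.3], [1.2, -0.7, 0.9]):  # "sample data"
    obs.observe(batch)
scale, zero_point = obs.calculate_qparams()
```

With the observed range [-0.7, 1.2], the scale is (1.2 - (-0.7)) / 255 and the zero point lands at -34; these are exactly the numbers that `convert_pt2e` would bake into the q/dq calls.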

 ### Result
 The result after these steps will be a reference quantized model, with quantize/dequantize operators being further decomposed. Example:
 
+#### Q/DQ Representation (default)
+We will have a (dq -> float32_op -> q) representation for all quantized operators
+
+```
+def quantized_linear(x_int8, x_scale, x_zero_point, weight_int8, weight_scale, weight_zero_point, bias_fp32, output_scale, output_zero_point):
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        x_int8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
+    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        weight_int8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
+    weight_permuted = torch.ops.aten.permute_copy.default(weight_fp32, [1, 0])
+    out_fp32 = torch.ops.aten.addmm.default(bias_fp32, x_fp32, weight_permuted)
+    out_int8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+        out_fp32, output_scale, output_zero_point, out_quant_min, out_quant_max, torch.int8)
+    return out_int8
+```
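The arithmetic behind this pattern can be sketched with plain Python scalars. The `dequantize`/`quantize`/`quantized_mul` helpers below are illustrative stand-ins for the `quantized_decomposed` ops, not real APIs:

```python
# Plain-Python stand-ins for the q/dq ops, showing the arithmetic of the
# (dq -> float32_op -> q) pattern. Names are illustrative, not real APIs.

def dequantize(v_int8, scale, zero_point):
    # int8 -> float32: subtract zero_point, multiply by scale
    return (v_int8 - zero_point) * scale

def quantize(v_fp32, scale, zero_point, quant_min=-128, quant_max=127):
    # float32 -> int8: divide by scale, shift by zero_point, clamp
    v = round(v_fp32 / scale) + zero_point
    return max(quant_min, min(quant_max, v))

def quantized_mul(a_int8, a_scale, a_zp, b_int8, b_scale, b_zp,
                  out_scale, out_zp):
    # dq -> float32 multiply -> q: the default representation for any op
    a_fp32 = dequantize(a_int8, a_scale, a_zp)
    b_fp32 = dequantize(b_int8, b_scale, b_zp)
    return quantize(a_fp32 * b_fp32, out_scale, out_zp)

# 64 * 0.05 = 3.2 and 40 * 0.1 = 4.0; product 12.8 re-quantized at scale 0.2
result = quantized_mul(64, 0.05, 0, 40, 0.1, 0, 0.2, 0)  # -> 64
```

Because every op is wrapped this way, a backend only needs to pattern-match the q/dq pairs around the float op it cares about to lower it to a fused quantized kernel.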
+
+#### Reference Quantized Model Representation
+(WIP, expected to be ready at end of August): we have a special representation for selected ops (e.g. quantized linear); other ops are represented as (dq -> float32_op -> q), and q/dq are decomposed into more primitive operators.
+
+You can get this representation by:
+`convert_pt2e(..., use_reference_representation=True)`
+
 ```
 # Reference Quantized Pattern for quantized linear
-def quantized_linear(x_int8, x_scale, x_zero_point, weight_int8, weight_scale, weight_zero_point, bias_int32, bias_scale, bias_zero_point, output_scale, output_zero_point):
+def quantized_linear(x_int8, x_scale, x_zero_point, weight_int8, weight_scale, weight_zero_point, bias_fp32, output_scale, output_zero_point):
     x_int16 = x_int8.to(torch.int16)
     weight_int16 = weight_int8.to(torch.int16)
     acc_int32 = torch.ops.out_dtype(torch.mm, torch.int32, (x_int16 - x_zero_point), (weight_int16 - weight_zero_point))
     acc_rescaled_int32 = torch.ops.out_dtype(torch.ops.aten.mul.Scalar, torch.int32, acc_int32, x_scale * weight_scale / output_scale)
-    bias_int32 = torch.ops.out_dtype(torch.ops.aten.mul.Scalar, bias_int32 - bias_zero_point, bias_scale / output_scale))
+    bias_scale = x_scale * weight_scale
+    bias_int32 = torch.ops.out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
+    bias_int32 = torch.ops.out_dtype(torch.ops.aten.mul.Scalar, torch.int32, bias_int32, bias_scale / output_scale)
     out_int8 = torch.ops.aten.clamp(acc_rescaled_int32 + bias_int32 + output_zero_point, qmin, qmax).to(torch.int8)
     return out_int8
 ```
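As a sanity check on the integer-only pattern, here is a scalar, torch-free sketch. The function name is illustrative, plain Python ints stand in for the int16/int32 tensors, and the bias is folded directly into output-scale units as a simplification:

```python
# Torch-free scalar sketch of the reference quantized linear pattern:
# integer accumulate -> rescale -> add bias -> clamp, checkable by hand.

def reference_quantized_linear_scalar(x_int8, x_scale, x_zero_point,
                                      w_int8, w_scale, w_zero_point,
                                      bias_fp32, out_scale, out_zero_point,
                                      qmin=-128, qmax=127):
    # integer multiply-accumulate on zero-point-corrected values
    # (mirrors the out_dtype(torch.mm, torch.int32, ...) step)
    acc_int32 = (x_int8 - x_zero_point) * (w_int8 - w_zero_point)
    # fold x_scale * w_scale / out_scale into one rescale of the accumulator
    acc_rescaled_int32 = round(acc_int32 * (x_scale * w_scale / out_scale))
    # bias contribution expressed directly in output-scale units
    bias_int32 = round(bias_fp32 / out_scale)
    out = acc_rescaled_int32 + bias_int32 + out_zero_point
    return max(qmin, min(qmax, out))  # clamp to the int8 range

# x = 0.5 at scale 0.01, w = 0.25 at scale 0.005, bias = 0.1, out_scale 0.005
out_int8 = reference_quantized_linear_scalar(50, 0.01, 0, 50, 0.005, 0,
                                             0.1, 0.005, 0)
# dequantized: 45 * 0.005 = 0.225, matching the float result 0.5*0.25 + 0.1
```

Unlike the default q/dq form, no float32 tensors appear between the inputs and the clamp, which is what lets backends map this pattern onto integer-only hardware.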
