Commit 5bf0fef
authored
Extend legalize-quant-to-math pass to support composite op (#2723)
**Note to the reviewers:**
This PR is based on #2722. The
changes related to this PR are localized in
`stablehlo/transforms/StablehloLegalizeQuantToMath.cpp` and
`stablehlo/tests/transforms/stablehlo_legalize_quant_to_int.mlir` files
only.
## Summary
The `quant-to-math` legalization of a composite op can be realized as:
1. Apply legalization to its decomposition.
2. Convert the quantized signature of the composite op to the integer
signature.
Note that both 1 and 2 are achieved __almost for free__ by the existing
patterns.
* (1) Because the existing pass applies to every func
in the module
* (2) As part of
[ConvertGenericOp](https://github.com/openxla/stablehlo/blob/03597b1e592129f0c79e99e5ed65dac7ebee240f/stablehlo/transforms/StablehloLegalizeQuantToMath.cpp#L1310)
conversion pattern.
Together with #2722, we can do
something like
### Step 1
```
$ cat input.mlir
func.func @decompose_composite_op(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> tensor<2xf32> {
%0 = stablehlo.uniform_quantize %arg0 : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 0.1:2>>
%1 = stablehlo.uniform_quantize %arg1 : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 0.1:2>>
%2 = stablehlo.add %0, %1 : tensor<2x!quant.uniform<i8:f32, 0.1:2>>
%3 = stablehlo.uniform_dequantize %2 : (tensor<2x!quant.uniform<i8:f32, 0.1:2>>) -> tensor<2xf32>
return %3 : tensor<2xf32>
}
```
### Step 2: Apply
https://github.com/openxla/stablehlo/blob/3a0cd9d12166d8426777206339b8562be64c55bc/stablehlo/transforms/Passes.td#L413
pass
As part of applying the pass, we provide the attribute names for
quantize/dequantize composites. With that, we may get something like
```mlir
func.func @decompose_composite_ops(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = stablehlo.composite "stablehlo.uniform_quantize" %arg0 {composite_attributes = {expressed_type = f32, scale = 1.000000e-01 : f64, storage_type = i8, storage_type_max = 127 : i64, storage_type_min = -128 : i64, zero_point = 0 : i64}, decomposition = @stablehlo.uniform_quantize.impl_0} : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>
%1 = stablehlo.composite "stablehlo.uniform_quantize" %arg1 {composite_attributes = {expressed_type = f32, scale = 1.000000e-01 : f64, storage_type = i8, storage_type_max = 127 : i64, storage_type_min = -128 : i64, zero_point = 0 : i64}, decomposition = @stablehlo.uniform_quantize.impl} : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>
%2 = stablehlo.composite "stablehlo.add" %0, %1 {decomposition = @stablehlo.add.impl} : (tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>, tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>
%3 = stablehlo.composite "stablehlo.uniform_dequantize" %2 {composite_attributes = {expressed_type = f32, scale = 1.000000e-01 : f64, storage_type = i8, storage_type_max = 127 : i64, storage_type_min = -128 : i64, zero_point = 0 : i64}, decomposition = @stablehlo.uniform_dequantize.impl} : (tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>) -> tensor<2xf32>
return %3 : tensor<2xf32>
}
func.func private @stablehlo.uniform_dequantize.impl(%arg0: tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>) -> tensor<2xf32> {
%0 = stablehlo.uniform_dequantize %arg0 : (tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func.func private @stablehlo.add.impl(%arg0: tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>, %arg1: tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>> {
%0 = stablehlo.add %arg0, %arg1 : tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>
return %0 : tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>
}
func.func private @stablehlo.uniform_quantize.impl(%arg0: tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>> {
%0 = stablehlo.uniform_quantize %arg0 : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>
return %0 : tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>
}
func.func private @stablehlo.uniform_quantize.impl_0(%arg0: tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>> {
%0 = stablehlo.uniform_quantize %arg0 : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>
return %0 : tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>
}
```
### Step 3: Apply `stablehlo-legalize-quant-to-math` pass
```mlir
func.func @decompose_composite_ops(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = stablehlo.composite "stablehlo.uniform_quantize" %arg0 {composite_attributes = {expressed_type = f32, scale = 1.000000e-01 : f64, storage_type = i8, storage_type_max = 127 : i64, storage_type_min = -128 : i64, zero_point = 0 : i64}, decomposition = @stablehlo.uniform_quantize.impl_0} : (tensor<2xf32>) -> tensor<2xi8>
%1 = stablehlo.composite "stablehlo.uniform_quantize" %arg1 {composite_attributes = {expressed_type = f32, scale = 1.000000e-01 : f64, storage_type = i8, storage_type_max = 127 : i64, storage_type_min = -128 : i64, zero_point = 0 : i64}, decomposition = @stablehlo.uniform_quantize.impl} : (tensor<2xf32>) -> tensor<2xi8>
%2 = stablehlo.composite "stablehlo.add" %0, %1 {decomposition = @stablehlo.add.impl} : (tensor<2xi8>, tensor<2xi8>) -> tensor<2xi8>
%3 = stablehlo.composite "stablehlo.uniform_dequantize" %2 {composite_attributes = {expressed_type = f32, scale = 1.000000e-01 : f64, storage_type = i8, storage_type_max = 127 : i64, storage_type_min = -128 : i64, zero_point = 0 : i64}, decomposition = @stablehlo.uniform_dequantize.impl} : (tensor<2xi8>) -> tensor<2xf32>
return %3 : tensor<2xf32>
}
func.func private @stablehlo.uniform_dequantize.impl(%arg0: tensor<2xi8>) -> tensor<2xf32> {
// ... decomposition of stablehlo.uniform_dequantize
}
func.func private @stablehlo.add.impl(%arg0: tensor<2xi8>, %arg1: tensor<2xi8>) -> tensor<2xi8> {
// decomposition of quantized stablehlo.add
}
func.func private @stablehlo.uniform_quantize.impl(%arg0: tensor<2xf32>) -> tensor<2xi8> {
// ... decomposition of stablehlo.uniform_quantize
}
func.func private @stablehlo.uniform_quantize.impl_0(%arg0: tensor<2xf32>) -> tensor<2xi8> {
// ... decomposition of stablehlo.uniform_quantize
}
```
cc @mahmoud-abuzaina1 parent 66f90d5 commit 5bf0fef
File tree
2 files changed
+63
-7
lines changed- stablehlo
- tests/transforms
- transforms
2 files changed
+63
-7
lines changedLines changed: 55 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2745 | 2745 | | |
2746 | 2746 | | |
2747 | 2747 | | |
| 2748 | + | |
| 2749 | + | |
| 2750 | + | |
| 2751 | + | |
| 2752 | + | |
| 2753 | + | |
| 2754 | + | |
| 2755 | + | |
| 2756 | + | |
| 2757 | + | |
| 2758 | + | |
| 2759 | + | |
| 2760 | + | |
| 2761 | + | |
| 2762 | + | |
| 2763 | + | |
| 2764 | + | |
| 2765 | + | |
| 2766 | + | |
| 2767 | + | |
| 2768 | + | |
| 2769 | + | |
| 2770 | + | |
| 2771 | + | |
| 2772 | + | |
| 2773 | + | |
| 2774 | + | |
| 2775 | + | |
| 2776 | + | |
| 2777 | + | |
| 2778 | + | |
| 2779 | + | |
| 2780 | + | |
| 2781 | + | |
| 2782 | + | |
| 2783 | + | |
| 2784 | + | |
| 2785 | + | |
| 2786 | + | |
| 2787 | + | |
| 2788 | + | |
| 2789 | + | |
| 2790 | + | |
| 2791 | + | |
| 2792 | + | |
| 2793 | + | |
| 2794 | + | |
| 2795 | + | |
| 2796 | + | |
| 2797 | + | |
| 2798 | + | |
| 2799 | + | |
| 2800 | + | |
| 2801 | + | |
| 2802 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1318 | 1318 | | |
1319 | 1319 | | |
1320 | 1320 | | |
1321 | | - | |
1322 | | - | |
1323 | | - | |
1324 | | - | |
1325 | | - | |
1326 | | - | |
1327 | | - | |
| 1321 | + | |
| 1322 | + | |
| 1323 | + | |
| 1324 | + | |
| 1325 | + | |
| 1326 | + | |
| 1327 | + | |
| 1328 | + | |
1328 | 1329 | | |
1329 | 1330 | | |
1330 | 1331 | | |
| |||
0 commit comments