|
10 | 10 | from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter |
11 | 11 | from executorch.backends.qualcomm.utils.constants import ( |
12 | 12 | QCOM_AXIS, |
| 13 | + QCOM_BLOCK_SIZE, |
13 | 14 | QCOM_DTYPE, |
14 | 15 | QCOM_ENCODING, |
15 | 16 | QCOM_QUANT_ATTRS, |
@@ -122,13 +123,25 @@ def _dequant_fold_params(self, n, quant_attrs, param): |
122 | 123 | scales = self._expand(quant_attrs[QCOM_SCALES], dim, axis) |
123 | 124 | offsets = self._expand(quant_attrs[QCOM_ZERO_POINTS], dim, axis) |
124 | 125 | param = param.sub(offsets).mul(scales).to(torch.float32).contiguous() |
125 | | - set_parameter(param, n.args[0], self.edge_program) |
| 126 | + elif quant_attrs[QCOM_ENCODING] in [ |
| 127 | + exir_ops.edge.pt2e_quant.dequantize_affine.default |
| 128 | + ]: |
| 129 | + param = torch.ops.pt2e_quant.dequantize_affine( |
| 130 | + param, |
| 131 | + block_size=quant_attrs[QCOM_BLOCK_SIZE], |
| 132 | + scale=quant_attrs[QCOM_SCALE], |
| 133 | + zero_point=quant_attrs[QCOM_ZERO_POINT], |
| 134 | + input_dtype=quant_attrs[QCOM_DTYPE], |
| 135 | + quant_min=quant_attrs[QCOM_QUANT_MIN], |
| 136 | + quant_max=quant_attrs[QCOM_QUANT_MAX], |
| 137 | + output_dtype=torch.float32, |
| 138 | + ) |
126 | 139 | else: |
127 | 140 | scale = quant_attrs[QCOM_SCALE] |
128 | 141 | offset = quant_attrs[QCOM_ZERO_POINT] |
129 | 142 | param = param.sub(offset).mul(scale).to(torch.float32).contiguous() |
130 | | - set_parameter(param, n.args[0], self.edge_program) |
131 | 143 |
|
| 144 | + set_parameter(param, n.args[0], self.edge_program) |
132 | 145 | n.args[0].meta["val"] = param |
133 | 146 |
|
134 | 147 | def _annotate_quant_attrs( |
|
0 commit comments