|
4 | 4 | # LICENSE file in the root directory of this source tree.
5 | 5 |
6 | 6 | # pyre-unsafe |
7 | | -from typing import List
| 7 | +from typing import Any, List |
8 | 8 |
| 9 | +import numpy as np |
9 | 10 | import torch |
10 | 11 |
11 | | -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
12 | | - |
13 | 12 | from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( |
14 | 13 | get_input_qparams, |
15 | 14 | get_output_qparams, |
|
19 | 18 | register_node_visitor, |
20 | 19 | ) |
21 | 20 | from executorch.backends.arm.tosa_mapping import TosaArg |
22 | | -from executorch.backends.arm.tosa_quant_utils import build_rescale_conv_output |
| 21 | +from executorch.backends.arm.tosa_quant_utils import build_rescale, build_rescale_v0_80 |
| 22 | +from executorch.backends.arm.tosa_specification import TosaSpecification |
23 | 23 | from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape |
24 | 24 |
25 | 25 |
26 | 26 | @register_node_visitor |
27 | | -class Conv2dVisitor(NodeVisitor): |
| 27 | +class Conv2dVisitor_0_80(NodeVisitor): |
28 | 28 | target = "aten.convolution.default" |
29 | 29 |
| 30 | + tosa_specs = [ |
| 31 | + TosaSpecification.create_from_string("TOSA-0.80+BI"), |
| 32 | + TosaSpecification.create_from_string("TOSA-0.80+MI"), |
| 33 | + ] |
| 34 | + |
30 | 35 | def __init__(self, *args): |
31 | 36 | super().__init__(*args) |
32 | 37 |
|
@@ -54,10 +59,13 @@ def adjust_pad_if_needed( |
54 | 59 | def define_node( |
55 | 60 | self, |
56 | 61 | node: torch.fx.Node, |
57 | | - tosa_graph: ts.TosaSerializer, |
| 62 | + tosa_graph: Any, |
58 | 63 | inputs: List[TosaArg], |
59 | 64 | output: TosaArg, |
60 | 65 | ) -> None: |
| 66 | + |
| 67 | + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore |
| 68 | + |
61 | 69 | input, weight, bias, stride, pad, dilation, _, _, group = inputs |
62 | 70 |
62 | 70 |
63 | 71 | # Get the attributes of convolution. |
@@ -170,14 +178,224 @@ def define_node( |
170 | 178 | input_scale = input_qparams[0].scale # type: ignore[possibly-undefined] # pyre-ignore [61] |
171 | 179 | weight_scale = input_qparams[1].scale # pyre-ignore [61] |
172 | 180 | output_qargs = get_output_qparams(node) |
173 | | - build_rescale_conv_output( |
174 | | - tosa_graph, |
175 | | - # pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined. |
176 | | - conv2d_res, # type: ignore[possibly-undefined] |
177 | | - output.name, |
178 | | - output.dtype, |
179 | | - [input_scale], |
180 | | - [weight_scale], |
181 | | - [output_qargs[0].scale], |
182 | | - output_qargs[0].zp, |
| 181 | + post_conv2d_scale = [ |
| 182 | + (inp * w) / out |
| 183 | + for inp, w, out in zip( |
| 184 | + [input_scale], [weight_scale], [output_qargs[0].scale] |
| 185 | + ) |
| 186 | + ] |
| 187 | + |
| 188 | + build_rescale_v0_80( |
| 189 | + tosa_fb=tosa_graph, |
| 190 | + scale=post_conv2d_scale, |
| 191 | + input_node=conv2d_res, # type: ignore[possibly-undefined] |
| 192 | + output_name=output.name, |
| 193 | + output_type=output.dtype, |
| 194 | + input_zp=0, |
| 195 | + output_zp=output_qargs[0].zp, |
| 196 | + per_channel=isinstance(weight_scale, torch.Tensor), |
| 197 | + ) # type: ignore[call-arg] |
| 198 | + |
| 199 | + |
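
For reference, the folded rescale factor above combines the input, weight, and output quantization scales into a single multiplier. A minimal sketch with hypothetical per-tensor scales (in the visitor the real values come from `get_input_qparams`/`get_output_qparams`):

```python
# Hypothetical per-tensor scales; the visitor reads these from the node's q-params.
input_scale, weight_scale, output_scale = 0.02, 0.005, 0.1

# The INT32 accumulator is in units of input_scale * weight_scale, so one
# multiplier maps it onto the output's int8 grid.
post_conv2d_scale = [
    (inp * w) / out
    for inp, w, out in zip([input_scale], [weight_scale], [output_scale])
]
print(post_conv2d_scale)  # ~[0.001], up to float rounding
```

With per-channel weight quantization, `weight_scale` is a tensor, so the single list entry becomes a per-channel tensor of factors, which appears to be why the `per_channel` flag is derived from `isinstance(weight_scale, torch.Tensor)`.
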
| 200 | +@register_node_visitor |
| 201 | +class Conv2dVisitor(NodeVisitor): |
| 202 | + target = "aten.convolution.default" |
| 203 | + |
| 204 | + tosa_specs = [ |
| 205 | + TosaSpecification.create_from_string("TOSA-1.0+INT"), |
| 206 | + TosaSpecification.create_from_string("TOSA-1.0+FP"), |
| 207 | + ] |
| 208 | + |
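
Two visitors now share the same target, with `tosa_specs` deciding which one handles a given lowering. A purely illustrative sketch of that dispatch idea (hypothetical registry; the real `register_node_visitor`/`NodeVisitor` machinery differs in detail):

```python
# Hypothetical registry keyed by (target, spec string), for illustration only.
_visitors = {}

def register(cls):
    for spec in cls.tosa_specs:
        _visitors[(cls.target, spec)] = cls
    return cls

@register
class Conv2dVisitor080Sketch:
    target = "aten.convolution.default"
    tosa_specs = ["TOSA-0.80+BI", "TOSA-0.80+MI"]

@register
class Conv2dVisitorSketch:
    target = "aten.convolution.default"
    tosa_specs = ["TOSA-1.0+INT", "TOSA-1.0+FP"]

# The active specification then selects the matching visitor:
assert _visitors[("aten.convolution.default", "TOSA-1.0+INT")] is Conv2dVisitorSketch
assert _visitors[("aten.convolution.default", "TOSA-0.80+BI")] is Conv2dVisitor080Sketch
```
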
| 209 | + def __init__(self, *args): |
| 210 | + super().__init__(*args) |
| 211 | + |
| 212 | + # torch.nn.Conv2d does not require the result of |
| 213 | + # `(input + 2 * pad - dilation * (weight - 1) - 1) / stride` |
| 214 | + # to be an integer, but TOSA currently strictly requires this property.
| 215 | + # This function adjusts the pad value to meet the requirement. |
| 216 | + def adjust_pad_if_needed( |
| 217 | + self, input_size: int, input_weight: int, stride: int, pad: int, dilation: int |
| 218 | + ) -> int: |
| 219 | + mod_remainder = ( |
| 220 | + input_size + 2 * pad - dilation * (input_weight - 1) - 1 |
| 221 | + ) % stride |
| 222 | + |
| 223 | + # No need to adjust |
| 224 | + if mod_remainder == 0: |
| 225 | + return pad |
| 226 | + |
| 227 | + if mod_remainder > pad: |
| 228 | + raise RuntimeError( |
| 229 | + "This case should be handled by the SizeAdjustConv2d pass, is it enabled?" |
| 230 | + ) |
| 231 | + return pad - mod_remainder |
| 232 | + |
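
A concrete (hypothetical) worked example of the adjustment: the caller only changes the bottom/right pad, so subtracting the remainder from one side is enough to make the TOSA output-size division exact.

```python
# Standalone restatement of adjust_pad_if_needed with made-up sizes.
def adjust_pad(input_size, kernel_size, stride, pad, dilation):
    remainder = (input_size + 2 * pad - dilation * (kernel_size - 1) - 1) % stride
    if remainder == 0:
        return pad
    # Larger remainders are expected to be handled by the SizeAdjustConv2d pass.
    assert remainder <= pad
    return pad - remainder

# Height 8, 3x3 kernel, stride 2, pad 1, dilation 1:
# (8 + 2*1 - 1*(3-1) - 1) % 2 = 7 % 2 = 1, so the adjusted (bottom) pad is 1 - 1 = 0,
# leaving total padding 1 and a numerator of 8 + 1 - 2 - 1 = 6, which divides by the stride.
print(adjust_pad(8, 3, 2, 1, 1))  # 0
```
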
| 233 | + def define_node( |
| 234 | + self, |
| 235 | + node: torch.fx.Node, |
| 236 | + tosa_graph: Any, |
| 237 | + inputs: List[TosaArg], |
| 238 | + output: TosaArg, |
| 239 | + ) -> None: |
| 240 | + |
| 241 | + import serializer.tosa_serializer as ts # type: ignore |
| 242 | + from tosa.RoundingMode import RoundingMode # type: ignore |
| 243 | + |
| 244 | + input, weight, bias, stride, pad, dilation, _, _, group = inputs |
| 245 | + |
| 246 | + # Get the attributes of convolution. |
| 247 | + attr = ts.TosaSerializerAttribute() |
| 248 | + pad_attr = [val for val in pad.special for _ in (0, 1)] |
| 249 | + stride_attr = stride.special |
| 250 | + dilation_attr = dilation.special |
| 251 | + |
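
The `pad_attr` comprehension above duplicates each torch padding value, turning the two-element `[padH, padW]` form into the four-element `[top, bottom, left, right]` layout TOSA expects. A tiny sketch with hypothetical values:

```python
pad_special = [1, 2]  # hypothetical torch padding: 1 along H, 2 along W
pad_attr = [val for val in pad_special for _ in (0, 1)]
print(pad_attr)  # [1, 1, 2, 2] -> [top, bottom, left, right]
```
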
| 252 | + # Adjust the pad value if needed to meet the |
| 253 | + # strict convolution output shape calculation. |
| 254 | + pad_attr[1] = self.adjust_pad_if_needed( |
| 255 | + input.shape[2], |
| 256 | + weight.shape[2], |
| 257 | + stride_attr[0], |
| 258 | + pad_attr[1], |
| 259 | + dilation_attr[0], |
| 260 | + ) |
| 261 | + pad_attr[3] = self.adjust_pad_if_needed( |
| 262 | + input.shape[3], |
| 263 | + weight.shape[3], |
| 264 | + stride_attr[1], |
| 265 | + pad_attr[3], |
| 266 | + dilation_attr[1], |
| 267 | + ) |
| 268 | + |
| 269 | + input_zp = 0 |
| 270 | + if inputs[0].dtype == ts.DType.INT8: |
| 271 | + # int8 input requires quantization information |
| 272 | + input_qparams = get_input_qparams(node) |
| 273 | + input_zp = input_qparams[0].zp |
| 274 | + |
| 275 | + tosa_graph.addConst([1], output.dtype, [input_zp], name=f"{node.name}_input_zp") |
| 276 | + tosa_graph.addConst([1], output.dtype, [0], name=f"{node.name}_weight_zp") |
| 277 | + acc_type = ( |
| 278 | + inputs[0].dtype if inputs[0].dtype == ts.DType.FP32 else ts.DType.INT32 |
| 279 | + ) |
| 280 | + |
| 281 | + # Non-bias case. |
| 282 | + if len(node.all_input_nodes) == 2: |
| 283 | + # Create a zero bias tensor if none is provided
| 284 | + out_channels = weight.shape[0] |
| 285 | + bias_name = "bias" + node.name.split("default", 1)[1] |
| 286 | + bias_type = output.dtype |
| 287 | + if output.dtype == ts.DType.INT8: |
| 288 | + # Conv is quantized to int8, but the TOSA operator has |
| 289 | + # output type int32, and the bias must be the same type |
| 290 | + # as the TOSA output type |
| 291 | + bias_type = ts.DType.INT32 |
| 292 | + bias = tosa_graph.addConst( |
| 293 | + [out_channels], |
| 294 | + bias_type, |
| 295 | + [0] * out_channels, |
| 296 | + name=bias_name, |
| 297 | + ) |
| 298 | + |
| 299 | + # The output type is int32 when input type is int8. |
| 300 | + conv2d_output_name = output.name |
| 301 | + if output.dtype == ts.DType.INT8: |
| 302 | + conv2d_res = tosa_graph.addIntermediate( |
| 303 | + tosa_shape(output.shape, output.dim_order), ts.DType.INT32 |
| 304 | + ) |
| 305 | + conv2d_output_name = conv2d_res.name |
| 306 | + |
| 307 | + # Given input.shape is (N, Ci, H, W), and weight.shape is (Co, Ci/G, H, W) |
| 308 | + in_channels = input.shape[1] |
| 309 | + out_channels = weight.shape[0] |
| 310 | + if (in_channels == group.number) and (out_channels % in_channels) == 0: |
| 311 | + """Depthwise convolution case""" |
| 312 | + # Reshape the torch-format weight tensor into the layout TOSA requires.
| 313 | + # https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d |
| 314 | + m_length = int(out_channels / in_channels) |
| 315 | + weight_post_shape = [ |
| 316 | + weight.shape[2], |
| 317 | + weight.shape[3], |
| 318 | + in_channels, |
| 319 | + m_length, |
| 320 | + ] |
| 321 | + |
| 322 | + weight_reshaped = tosa_graph.addIntermediate( |
| 323 | + weight_post_shape, |
| 324 | + weight.dtype, |
| 325 | + ) |
| 326 | + shape = tosa_graph.addConst( |
| 327 | + np.array(weight_post_shape).shape, |
| 328 | + ts.DType.SHAPE, |
| 329 | + np.array(weight_post_shape), |
| 330 | + name=weight_reshaped.name + "_shape", |
| 331 | + ) |
| 332 | + |
| 333 | + attr = ts.TosaSerializerAttribute() |
| 334 | + attr.ReshapeAttribute() |
| 335 | + tosa_graph.addOperator( |
| 336 | + ts.TosaOp.Op().RESHAPE, |
| 337 | + [weight.name, shape.name], |
| 338 | + [weight_reshaped.name], |
| 339 | + attr, |
| 340 | + ) |
| 341 | + |
| 342 | + tosa_op = ts.TosaOp.Op().DEPTHWISE_CONV2D |
| 343 | + weight_name = weight_reshaped.name |
| 344 | + |
| 345 | + attr.DepthwiseConv2dAttribute( |
| 346 | + pad=pad_attr, |
| 347 | + stride=stride_attr, |
| 348 | + dilation=dilation_attr, |
| 349 | + local_bound=False, |
| 350 | + acc_type=acc_type, |
| 351 | + ) |
| 352 | + else: |
| 353 | + """Regular convolution case""" |
| 354 | + tosa_op = ts.TosaOp.Op().CONV2D |
| 355 | + weight_name = weight.name |
| 356 | + |
| 357 | + attr.Conv2dAttribute( |
| 358 | + pad=pad_attr, |
| 359 | + stride=stride_attr, |
| 360 | + dilation=dilation_attr, |
| 361 | + local_bound=False, |
| 362 | + acc_type=acc_type, |
| 363 | + ) |
| 364 | + |
| 365 | + tosa_graph.addOperator( |
| 366 | + tosa_op, |
| 367 | + [ |
| 368 | + input.name, |
| 369 | + weight_name, |
| 370 | + bias.name, |
| 371 | + f"{node.name}_input_zp", |
| 372 | + f"{node.name}_weight_zp", |
| 373 | + ], |
| 374 | + [conv2d_output_name], |
| 375 | + attr, |
| 376 | + ) |
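
For the depthwise branch, the RESHAPE above only recomputes the target shape; the backend's earlier passes are assumed to have put the weight data itself in the right order. A small sketch of the shape arithmetic with hypothetical dimensions:

```python
import numpy as np

# Hypothetical depthwise case: 8 input channels, channel multiplier M = 2, 3x3 kernel,
# so torch supplies a (Co=16, Ci/G=1, H=3, W=3) weight.
in_channels, multiplier, k = 8, 2, 3
torch_weight = np.zeros((in_channels * multiplier, 1, k, k))

m_length = torch_weight.shape[0] // in_channels
weight_post_shape = [torch_weight.shape[2], torch_weight.shape[3], in_channels, m_length]
print(weight_post_shape)  # [3, 3, 8, 2] -> the (H, W, C, M) layout linked in the spec comment
```
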
| 377 | + |
| 378 | + # For quantized convolution, rescale the output value back into the
| 379 | + # integer value domain of the next op. Otherwise return float32 output.
| 380 | + if inputs[0].dtype == ts.DType.INT8: |
| 381 | + # Get scale_factor from input, weight, and output. |
| 382 | + input_scale = input_qparams[0].scale # type: ignore[possibly-undefined] # pyre-ignore [61] |
| 383 | + weight_scale = input_qparams[1].scale # pyre-ignore [61] |
| 384 | + output_qargs = get_output_qparams(node) |
| 385 | + post_conv2d_scale = [ |
| 386 | + (inp * w) / out |
| 387 | + for inp, w, out in zip( |
| 388 | + [input_scale], [weight_scale], [output_qargs[0].scale] |
| 389 | + ) |
| 390 | + ] |
| 391 | + build_rescale( |
| 392 | + tosa_fb=tosa_graph, |
| 393 | + scale=post_conv2d_scale, |
| 394 | + input_node=conv2d_res, # type: ignore[possibly-undefined] |
| 395 | + output_name=output.name, |
| 396 | + output_type=output.dtype, |
| 397 | + input_zp=0, |
| 398 | + output_zp=output_qargs[0].zp, |
| 399 | + per_channel=isinstance(weight_scale, torch.Tensor), |
| 400 | + rounding_mode=RoundingMode.SINGLE_ROUND, |
183 | 401 | ) |