Skip to content

Commit 5b1b565

Browse files
committed
Support Transposed Convolution in XNNPACK delegate
1 parent ae3d558 commit 5b1b565

16 files changed

+654
-38
lines changed

backends/xnnpack/_passes/fuse_activation_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def call(self, graph_module: torch.fx.GraphModule):
6868
preceding_op.op == "call_function"
6969
and preceding_op.target in self.FUSEABLE_OPS
7070
):
71+
if len(preceding_op.users) > 1:
72+
continue
7173
# Delete activation, and embed metadata into preceding op
7274
output_min_max = self.get_output_min_max_from_activation(
7375
activation_node

backends/xnnpack/_passes/fuse_batch_norm_with_conv.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,18 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import operator
8+
from typing import cast, List
89

910
import torch
1011

12+
from executorch.backends.transforms import get_shape
1113
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
1214

13-
from executorch.backends.xnnpack.utils.utils import get_param_tensor, is_param_node
15+
from executorch.backends.xnnpack.utils.utils import (
16+
get_input_node,
17+
get_param_tensor,
18+
is_param_node,
19+
)
1420
from executorch.exir import ExportedProgram
1521
from executorch.exir.dialects._ops import ops as exir_ops
1622
from executorch.exir.pass_base import PassResult
@@ -134,6 +140,16 @@ def can_fuse(
134140
Determine whether a batch norm node can be fused with a preceding conv node.
135141
"""
136142

143+
is_transpose = conv.args[6]
144+
kernel_node = get_input_node(conv, 1)
145+
kernel_shape = get_shape(kernel_node)
146+
stride = cast(List[int], conv.args[3])
147+
148+
if is_transpose and (
149+
kernel_shape[-1] != stride[0] or kernel_shape[-2] != stride[1]
150+
):
151+
return False
152+
137153
# All the users of batchnorm node must be getitem ops. batchnorm
138154
# returns a 3-element tuple. Each user must only access the first
139155
# element of the tuple.

backends/xnnpack/operators/node_visitor.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def _check_per_channel_group_params(
337337
# For now group quantization is only supported for 4b weights
338338
assert quant_params.is_qc4w, "Only 4b group quantization is supported"
339339

340-
def define_tensor(
340+
def define_tensor( # noqa: C901
341341
self,
342342
tensor: torch.fx.Node,
343343
xnn_graph: XNNGraph,
@@ -346,6 +346,8 @@ def define_tensor(
346346
swap_nc_for_depthwise_weights: bool = False,
347347
quant_params: Optional[QuantParams] = None,
348348
fp32_static_weights: bool = False,
349+
swap_in_out_for_transpose_weights: bool = False,
350+
groups: int = 1,
349351
) -> None:
350352
"""
351353
Defines a tensor value in the XNNGraph
@@ -365,6 +367,9 @@ def define_tensor(
365367
swap will happen before converting to nhwc.
366368
quant_params: Quantization meta data for this tensor, None if it is not quantized
367369
fp32_static_weights: XNN_FLAG_FP32_STATIC_WEIGHTS for fp16 conv
370+
swap_in_out_for_transpose_weights: bool to indicate whether tensor shape should be
371+
permuted and reshaped from (inc, oc/groups, height, width) to (oc, inc/groups, height, width)
372+
groups: number of groups for swap_in_out_for_transpose_weights
368373
"""
369374

370375
if tensor in vals_to_ids:
@@ -397,12 +402,16 @@ def define_tensor(
397402
swap_nc_for_depthwise_weights,
398403
quant_params,
399404
fp32_static_weights,
405+
swap_in_out_for_transpose_weights,
406+
groups,
400407
)
401408

402409
# convert tensor shape must reflect memory format, default is contiguous, so
403410
# only permute shape if we are converting the tensor to nhwc format
404411
if swap_nc_for_depthwise_weights:
405412
dims = [dims[1], dims[0]] + dims[2:]
413+
if swap_in_out_for_transpose_weights:
414+
dims = [dims[1] * groups, dims[0] // groups] + dims[2:]
406415
if convert_to_nhwc:
407416
check_or_raise(len(dims) == 4, "Converting to nhwc requires 4d tensor")
408417
dims = [dims[i] for i in PERM_NCHW_TO_NHWC]
@@ -433,6 +442,14 @@ def define_tensor(
433442
else:
434443
assert f"Unsupported weight per channel quantization axis for depthwise conv2d: {quant_params.axis}, expecting 0."
435444

445+
if swap_in_out_for_transpose_weights and (
446+
quant_params and quant_params.per_channel
447+
):
448+
if quant_params.axis == 0:
449+
quant_params.axis = len(dims) - 1
450+
else:
451+
assert f"Unsupported weight per channel quantization axis for conv_transpose2d: {quant_params.axis}, expecting 0."
452+
436453
# Serialize tensor value
437454
ser_val = (
438455
XValue(xvalue_union=tvalue)
@@ -495,6 +512,8 @@ def get_serialized_buffer_index(
495512
swap_nc_for_depthwise_weights: bool,
496513
quant_params: Optional[QuantParams],
497514
fp32_static_weights: bool = False,
515+
swap_in_out_for_transpose_weights: bool = False,
516+
groups: int = 1,
498517
) -> int:
499518
"""
500519
If tensor holds some constant data, serialize it and return the
@@ -546,6 +565,16 @@ def get_serialized_buffer_index(
546565
dims=((1, 0) + tuple(range(2, const_val.dim())))
547566
).contiguous()
548567

568+
if swap_in_out_for_transpose_weights:
569+
shape = const_val.shape
570+
const_val = const_val.reshape(
571+
(groups, const_val.shape[0] // groups) + const_val.shape[1:]
572+
)
573+
const_val = const_val.permute((0, 2, 1) + tuple(range(3, const_val.dim())))
574+
const_val = const_val.reshape(
575+
(shape[1] * groups, shape[0] // groups) + shape[2:]
576+
).contiguous()
577+
549578
if convert_to_nhwc:
550579
const_val = const_val.to(memory_format=torch.channels_last)
551580

backends/xnnpack/operators/op_conv2d.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from executorch.backends.xnnpack.operators.quant_params import QuantParams
1717
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
1818
XNNConv2d,
19+
XNNConvTranspose2d,
1920
XNNDepthwiseConv2d,
2021
XNNGraph,
2122
XNode,
@@ -52,21 +53,31 @@ def define_node(
5253
) # NHWC input
5354
kwargs["input1_id"] = vals_to_ids[get_input_node(node, 0)]
5455

55-
# filter shape for pytorch convolution is (oc, inc/groups, height, width)
56-
# shape for xnnpack convolution is (oc, height, width, inc/groups), to convert
57-
# to the proper shape, this is essentially a NCHW to NHWC conversion
56+
# filter shape for pytorch convolution is (oc, inc/groups, height, width),
57+
# filter shape for pytorch transpose convolution is (inc, oc/groups, height, width),
58+
# shape for xnnpack convolution is (oc, height, width, inc/groups),
59+
# shape for xnnpack transpose convolution is (oc, height, width, inc/groups),
60+
# to convert to the proper shape, this is essentially a NCHW to NHWC conversion
5861
kernel_node = get_input_node(node, 1)
5962
kernel_shape = get_shape(kernel_node)
6063
groups = cast(int, node.args[8])
61-
group_input_channels = kernel_shape[1]
62-
group_output_channels = int(kernel_shape[0] / groups)
64+
is_transpose = node.args[6]
65+
66+
if is_transpose:
67+
group_input_channels = int(kernel_shape[0] / groups)
68+
group_output_channels = kernel_shape[1]
69+
else:
70+
group_input_channels = kernel_shape[1]
71+
group_output_channels = int(kernel_shape[0] / groups)
6372

6473
# XNNPACK expects the kernel's N and C dimensions to be swapped for
6574
# Depthwise Convolution, which occurs under the following conditions:
6675
# 1) groups = input_channels (i.e. group_input_channels = 1)
6776
# 2) output_channels is a positive integer multiple of input channels
68-
is_depthwise_conv = (group_input_channels == 1) and (
69-
group_output_channels % group_input_channels == 0
77+
is_depthwise_conv = (
78+
(group_input_channels == 1)
79+
and (group_output_channels % group_input_channels == 0)
80+
and not is_transpose
7081
)
7182
weight_quant_params = QuantParams.from_weights(
7283
kernel_node, self._exported_program
@@ -81,6 +92,8 @@ def define_node(
8192
swap_nc_for_depthwise_weights=is_depthwise_conv,
8293
quant_params=weight_quant_params,
8394
fp32_static_weights=fp32_static_weights,
95+
swap_in_out_for_transpose_weights=is_transpose,
96+
groups=groups,
8497
)
8598
kwargs["filter_id"] = vals_to_ids[get_input_node(node, 1)]
8699

@@ -120,10 +133,6 @@ def define_node(
120133
if len(padding) == 1:
121134
padding = padding + padding
122135

123-
# args[6] = transposed
124-
check_or_raise(
125-
not cast(bool, node.args[6]), "No support for transposed convolution"
126-
)
127136
# args[7] = output padding
128137
check_or_raise(
129138
all(out_pad == 0 for out_pad in cast(List[int], node.args[7])),
@@ -152,6 +161,8 @@ def define_node(
152161

153162
if is_depthwise_conv:
154163
conv_node_type = XNNDepthwiseConv2d
164+
elif is_transpose:
165+
conv_node_type = XNNConvTranspose2d
155166
else:
156167
conv_node_type = XNNConv2d
157168

backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import cast, List, Optional, Tuple
1010

1111
import torch
12+
from executorch.backends.xnnpack.operators.quant_params import QuantParams
1213
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
1314
ConfigPrecisionType,
1415
XNNPartitionerConfig,
@@ -327,11 +328,23 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
327328
why(node, "Only support 1D + 2D Conv")
328329
return False # Only support 1D + 2D Conv
329330

330-
transposed = cast(bool, node.args[6])
331-
if transposed:
332-
why(node, "Transposed Conv is not supported")
333-
return False # Currently don't support transposed conv
331+
kernel_node = get_input_node(node, 1)
332+
weight_quant_params = QuantParams.from_weights(kernel_node, ep)
334333

334+
is_transpose = node.args[6]
335+
groups = cast(int, node.args[8])
336+
if (
337+
is_transpose
338+
and weight_quant_params is not None
339+
and weight_quant_params.per_channel
340+
and groups > 1
341+
):
342+
why(
343+
node,
344+
"XNNPACK does not support per input channel quantization"
345+
"for transpose convolutions with groups > 1",
346+
)
347+
return False
335348
return True
336349

337350
def supported_precision_types(self):

backends/xnnpack/partition/configs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
torch.nn.BatchNorm2d,
7474
torch.nn.BatchNorm1d,
7575
torch.nn.Conv2d,
76+
torch.nn.ConvTranspose2d,
7677
torch.nn.Linear,
7778
torch.nn.functional.linear,
7879
torch.nn.PReLU, # Without this, the PReLU weight becomes not a get_attr
@@ -130,8 +131,11 @@
130131
torch.nn.functional.conv1d,
131132
torch.ao.nn.quantized.reference.modules.conv.Conv1d,
132133
torch.nn.Conv2d,
134+
torch.nn.ConvTranspose2d,
133135
torch.nn.functional.conv2d,
136+
torch.nn.functional.conv_transpose2d,
134137
torch.ao.nn.quantized.reference.modules.conv.Conv2d,
138+
torch.ao.nn.quantized.reference.modules.conv.ConvTranspose2d,
135139
torch.nn.BatchNorm1d,
136140
torch.nn.BatchNorm2d,
137141
]

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -979,6 +979,54 @@ Error defineConv2dNode(
979979
return Error::Ok;
980980
}
981981

982+
/*
983+
Define serialized conv_transpose2d node into the subgraph, using the remapped
984+
ids to map the serialized ids, to the new ids generated when defining the tensor
985+
value
986+
*/
987+
Error defineConvTranspose2dNode(
988+
xnn_subgraph_t subgraph_ptr,
989+
const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
990+
const NodePtr node,
991+
const fb_xnnpack::XNNGraph* graph) noexcept {
992+
MAYBE_UNUSED(graph);
993+
auto graph_node = node->xnode_union_as_XNNConvTranspose2d();
994+
995+
std::pair<float, float> min_max = getOutputMinMax(node);
996+
xnn_status status = xnn_define_deconvolution_2d(
997+
subgraph_ptr,
998+
graph_node->padding_top(),
999+
graph_node->padding_right(),
1000+
graph_node->padding_bottom(),
1001+
graph_node->padding_left(),
1002+
graph_node->adjustment_height(),
1003+
graph_node->adjustment_width(),
1004+
graph_node->kernel_height(),
1005+
graph_node->kernel_width(),
1006+
graph_node->subsampling_height(),
1007+
graph_node->subsampling_width(),
1008+
graph_node->dilation_height(),
1009+
graph_node->dilation_width(),
1010+
graph_node->groups(),
1011+
graph_node->group_input_channels(),
1012+
graph_node->group_output_channels(),
1013+
min_max.first,
1014+
min_max.second,
1015+
remapped_ids.at(graph_node->input1_id()),
1016+
remapped_ids.at(graph_node->filter_id()),
1017+
remapped_ids.at(graph_node->bias_id()),
1018+
remapped_ids.at(graph_node->output_id()),
1019+
graph_node->flags());
1020+
ET_CHECK_OR_RETURN_ERROR(
1021+
status == xnn_status_success,
1022+
Internal,
1023+
"Failed to create deconvolution node %i with code: %s",
1024+
node->debug_handle(),
1025+
xnn_status_to_string(status));
1026+
1027+
return Error::Ok;
1028+
}
1029+
9821030
/*
9831031
Define serialized maxpool2d node into the subgraph, using the remapped ids
9841032
to map the serialized ids, to the new ids generated when defining the
@@ -1840,6 +1888,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
18401888
_DEFINE(StaticTranspose)
18411889
_DEFINE(Clamp)
18421890
_DEFINE(Conv2d)
1891+
_DEFINE(ConvTranspose2d)
18431892
_DEFINE(Div)
18441893
_DEFINE(StaticResizeBilinear2D)
18451894
_DEFINE(StaticConstantPad)

backends/xnnpack/serialization/runtime_schema.fbs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ union XNodeUnion {
137137
XNNScaledDotProductAttention,
138138
XNNBatchMatrixMultiply: _XNNNode2x1,
139139
XNNConcatenate5: _XNNCat,
140+
XNNConvTranspose2d: _XNNNodeConv,
140141
}
141142

142143
union XValueUnion {

backends/xnnpack/serialization/schema.fbs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ union XNodeUnion {
133133
XNNScaledDotProductAttention,
134134
XNNBatchMatrixMultiply: _XNNNode2x1,
135135
XNNConcatenate5: _XNNCat,
136+
XNNConvTranspose2d: _XNNNodeConv,
136137
}
137138

138139
union XValueUnion {

backends/xnnpack/serialization/xnnpack_graph_schema.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,11 @@ class XNNConv2d(XNNNodeConv):
103103
pass
104104

105105

106+
@dataclass
class XNNConvTranspose2d(XNNNodeConv):
    # Node type for transposed 2D convolution; it declares no fields of its
    # own and relies entirely on those inherited from XNNNodeConv.
    pass
109+
110+
106111
@dataclass
107112
class XNNAdd(XNNNode2x1):
108113
pass
@@ -336,6 +341,7 @@ class XNNScaledDotProductAttention:
336341
XNNStaticTranspose,
337342
XNNClamp,
338343
XNNConv2d,
344+
XNNConvTranspose2d,
339345
XNNDiv,
340346
XNNStaticResizeBilinear2D,
341347
XNNStaticConstantPad,

0 commit comments

Comments
 (0)