diff --git a/backends/xnnpack/_passes/fuse_activation_pass.py b/backends/xnnpack/_passes/fuse_activation_pass.py index 289e2b03fdb..eeaf0f2a84f 100644 --- a/backends/xnnpack/_passes/fuse_activation_pass.py +++ b/backends/xnnpack/_passes/fuse_activation_pass.py @@ -68,6 +68,10 @@ def call(self, graph_module: torch.fx.GraphModule): preceding_op.op == "call_function" and preceding_op.target in self.FUSEABLE_OPS ): + # Check that current activation is the only user of the preceding op + # so that we can fuse the activation into the preceding op + if len(preceding_op.users) > 1: + continue # Delete activation, and embed metadata into preceding op output_min_max = self.get_output_min_max_from_activation( activation_node diff --git a/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py b/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py index 8014edfb1c3..b0f4779eb4c 100644 --- a/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py +++ b/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py @@ -82,6 +82,7 @@ def call(self, graph_module: torch.fx.GraphModule): # as an arg) eps = bn.args[-1] + is_transpose = conv.args[6] # Compute the updated weight and bias after fusing conv op # with batchnorm op. fused_weight, fused_bias = fuse_conv_bn_weights( @@ -92,6 +93,7 @@ def call(self, graph_module: torch.fx.GraphModule): eps, bn_weight, bn_bias, + is_transpose, ) # Modify the graph by updating the weight and bias of conv op diff --git a/backends/xnnpack/_passes/tag_implicit_q_dq_pass.py b/backends/xnnpack/_passes/tag_implicit_q_dq_pass.py index edbe9b44dcd..3c6345e28a2 100644 --- a/backends/xnnpack/_passes/tag_implicit_q_dq_pass.py +++ b/backends/xnnpack/_passes/tag_implicit_q_dq_pass.py @@ -83,11 +83,16 @@ def is_dynamically_quantized(self, node: torch.fx.Node) -> bool: return is_dynamic_qdq(node) def is_supported_quant_op(self, node: torch.fx.Node) -> bool: - return ( - node.op == "call_function" - and cast(torch._ops.OpOverload, node.target).name() - in SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET - ) + if node.op != "call_function": + return False + + op_name = cast(torch._ops.OpOverload, node.target).name() + + # Weight and Input should both be quantized + if op_name == exir_ops.edge.aten.convolution.default.name(): + return is_dequant(node.args[1]) + + return op_name in SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET def is_supported_quant_module(self, node: torch.fx.Node) -> bool: is_supported = ( diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index 018ce1e568f..e0871089ec8 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -337,15 +337,16 @@ def _check_per_channel_group_params( # For now group quantization is only supported for 4b weights assert quant_params.is_qc4w, "Only 4b group quantization is supported" - def define_tensor( + def define_tensor( # noqa: C901 self, tensor: torch.fx.Node, xnn_graph: XNNGraph, vals_to_ids: Dict[torch.fx.Node, int], convert_to_nhwc: bool = False, - swap_nc_for_depthwise_weights: bool = False, + swap_in_out_for_weights: bool = False, quant_params: Optional[QuantParams] = None, fp32_static_weights: bool = False, + groups: int = 1, ) -> None: """ Defines an tensor value into the XNNGraph @@ -357,16 +358,21 @@ def define_tensor( their corresponding ids in XNNGraph convert_to_nhwc: bool to indicate whether tensor shape should be permuted to reflect the nhwc memory format. - swap_nc_for_depthwise_weights: bool to indicate whether tensor shape - should be permuted such that the N and C dimensions are - swapped, which should be used for depthwise convolution + swap_in_out_for_weights: bool to indicate whether tensor shape should be + permuted and reshape from (inc, oc/groups, height, width) to (oc, inc/groups, height, width) + , which should be used for depthwise/transpose convolution weights. This is only valid for tensors which hold constant data. If used along with convert_to_nhwc, this swap will happen before converting to nhwc. quant_params: Quantization meta data for this tensor, None if it is not quantized fp32_static_weights: XNN_FLAG_FP32_STATIC_WEIGHTS for fp16 conv + groups: number of groups for swap_in_out_for_weights """ + assert ( + swap_in_out_for_weights or groups == 1 + ), "groups is option for swap_in_out_for_weights" + if tensor in vals_to_ids: return @@ -394,15 +400,16 @@ def define_tensor( xnn_graph, vals_to_ids, convert_to_nhwc, - swap_nc_for_depthwise_weights, + swap_in_out_for_weights, quant_params, fp32_static_weights, + groups, ) # convert tensor shape must reflect memory format, default is contiguous, so # only permute shape if we are converting the tensor to nhwc format - if swap_nc_for_depthwise_weights: - dims = [dims[1], dims[0]] + dims[2:] + if swap_in_out_for_weights: + dims = [dims[1] * groups, dims[0] // groups] + dims[2:] if convert_to_nhwc: check_or_raise(len(dims) == 4, "Converting to nhwc requires 4d tensor") dims = [dims[i] for i in PERM_NCHW_TO_NHWC] @@ -422,16 +429,16 @@ def define_tensor( ) # Override the quant params axis since we have - # updated the weights for depthwise, with that the out_channels dim + # updated the weights for depthwise/ transposed_conv2d, with that the out_channels dim # will be dims[3] instead of dims[0]. Let's update the per_channel # quant axis to match the new weight tensor before serializing - if swap_nc_for_depthwise_weights and ( - quant_params and quant_params.per_channel - ): + if swap_in_out_for_weights and (quant_params and quant_params.per_channel): if quant_params.axis == 0: quant_params.axis = len(dims) - 1 + elif quant_params.axis == 1: + quant_params.axis = 0 else: - assert f"Unsupported weight per channel quantization axis for depthwise conv2d: {quant_params.axis}, expecting 0." + assert f"Unsupported weight per channel quantization axis for depthwise conv2d / conv_transpose2d : {quant_params.axis}, expecting 0 / 1." # Serialize tensor value ser_val = ( @@ -492,9 +499,10 @@ def get_serialized_buffer_index( xnn_graph: XNNGraph, vals_to_ids: Dict[torch.fx.Node, int], convert_to_nhwc: bool, - swap_nc_for_depthwise_weights: bool, + swap_in_out_for_weights: bool, quant_params: Optional[QuantParams], fp32_static_weights: bool = False, + groups: int = 1, ) -> int: """ If tensor holds some constant data, serialize it and return the @@ -507,24 +515,30 @@ def get_serialized_buffer_index( their corresponding ids in XNNGraph convert_to_nhwc: bool to indicate whether tensor shape should be permuted to reflect the nhwc memory format. - swap_nc_for_depthwise_weights: bool to indicate whether tensor shape - should be permuted such that the N and C dimensions are - swapped, which should be used for depthwise convolution + swap_in_out_for_weights: bool to indicate whether tensor shape should be + permuted and reshape from (inc, oc/groups, height, width) to (oc, inc/groups, height, width) + , which should be used for depthwise/transpose convolution weights. This is only valid for tensors which hold constant data. If used along with convert_to_nhwc, this swap will happen before converting to nhwc. quant_params: Quantization meta data for this tensor, None if it is not quantize fp32_static_weights: bool to indicate whether tensor is fp32 static weights + groups: groups for swap_in_out_for_weights Returns: buffer_idx: idx of the serialized data. 0 If not associated constant data """ + + assert ( + swap_in_out_for_weights or groups == 1 + ), "groups is option for swap_in_out_for_weights" + # The get_attr node is the input to quant_params. get_attr_node = tensor if quant_params is None else quant_params.q_input if not is_param_node(self.exported_program, get_attr_node): check_or_raise( - not swap_nc_for_depthwise_weights, + not swap_in_out_for_weights, "Swapping N and C dimensions is only valid for constant data tensors", ) return 0 @@ -541,9 +555,16 @@ def get_serialized_buffer_index( # ensure that the const is fp32 const_val = const_val.to(dtype=torch.float32).contiguous() - if swap_nc_for_depthwise_weights: - const_val = const_val.permute( - dims=((1, 0) + tuple(range(2, const_val.dim()))) + if swap_in_out_for_weights: + # Permute and reshape the tensor from (inc, oc/groups, height, width) to (oc, inc/groups, height, width) + # which should be used for depthwise/transpose convolution weights for XNNPACK + shape = const_val.shape + const_val = const_val.reshape( + (groups, const_val.shape[0] // groups) + const_val.shape[1:] + ) + const_val = const_val.permute((0, 2, 1) + tuple(range(3, const_val.dim()))) + const_val = const_val.reshape( + (shape[1] * groups, shape[0] // groups) + shape[2:] ).contiguous() if convert_to_nhwc: diff --git a/backends/xnnpack/operators/op_conv2d.py b/backends/xnnpack/operators/op_conv2d.py index 62c30c010a1..1272f1b5250 100644 --- a/backends/xnnpack/operators/op_conv2d.py +++ b/backends/xnnpack/operators/op_conv2d.py @@ -16,6 +16,7 @@ from executorch.backends.xnnpack.operators.quant_params import QuantParams from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( XNNConv2d, + XNNConvTranspose2d, XNNDepthwiseConv2d, XNNGraph, XNode, @@ -52,35 +53,57 @@ def define_node( ) # NHWC input kwargs["input1_id"] = vals_to_ids[get_input_node(node, 0)] - # filter shape for pytorch convolution is (oc, inc/groups, height, width) - # shape for xnnpack convolution is (oc, height, width, inc/groups), to convert - # to the proper shape, this is essentially a NCHW to NHWC conversion + # filter shape for pytorch convolution is (oc, inc/groups, height, width), + # filter shape for pytorch transpose convolution is (inc, oc/groups, height, width), + # shape for xnnpack convolution is (oc, height, width, inc/groups), + # shape for xnnpack transpose convolution is (oc, height, width, inc/groups), + # to convert to the proper shape, this is essentially a NCHW to NHWC conversion kernel_node = get_input_node(node, 1) kernel_shape = get_shape(kernel_node) groups = cast(int, node.args[8]) - group_input_channels = kernel_shape[1] - group_output_channels = int(kernel_shape[0] / groups) + is_transpose = node.args[6] + + if is_transpose: + group_input_channels = int(kernel_shape[0] / groups) + group_output_channels = kernel_shape[1] + else: + group_input_channels = kernel_shape[1] + group_output_channels = int(kernel_shape[0] / groups) # XNNPACK expects the kernel's N and C dimensions to be swapped for # Depthwise Convolution, which occurs under the following conditions: # 1) groups = input_channels (i.e. group_input_channels = 1) # 2) output_channels is a positive integer multiple of input channels - is_depthwise_conv = (group_input_channels == 1) and ( - group_output_channels % group_input_channels == 0 + is_depthwise_conv = ( + (group_input_channels == 1) + and (group_output_channels % group_input_channels == 0) + and not is_transpose ) weight_quant_params = QuantParams.from_weights( kernel_node, self._exported_program ) fp32_static_weights = kernel_node.meta["val"].dtype == torch.float16 + if weight_quant_params is not None and weight_quant_params.per_channel: + if is_transpose: + check_or_raise( + weight_quant_params.axis == 1 and groups == 1, + "XNNPACK currently only supports per output channel quantization with groups == 1 for transpose convolutions", + ) + elif is_depthwise_conv: + check_or_raise( + weight_quant_params.axis == 0, + "XNNPACK currently only supports per input channel quantization for depthwise convolutions", + ) self.define_tensor( kernel_node, xnn_graph, vals_to_ids, convert_to_nhwc=True, - swap_nc_for_depthwise_weights=is_depthwise_conv, + swap_in_out_for_weights=is_depthwise_conv or is_transpose, quant_params=weight_quant_params, fp32_static_weights=fp32_static_weights, + groups=groups if is_transpose else 1, ) kwargs["filter_id"] = vals_to_ids[get_input_node(node, 1)] @@ -120,10 +143,6 @@ def define_node( if len(padding) == 1: padding = padding + padding - # args[6] = transposed - check_or_raise( - not cast(bool, node.args[6]), "No support for transposed convolution" - ) # args[7] = output padding check_or_raise( all(out_pad == 0 for out_pad in cast(List[int], node.args[7])), @@ -152,6 +171,8 @@ def define_node( if is_depthwise_conv: conv_node_type = XNNDepthwiseConv2d + elif is_transpose: + conv_node_type = XNNConvTranspose2d else: conv_node_type = XNNConv2d diff --git a/backends/xnnpack/partition/config/gemm_configs.py b/backends/xnnpack/partition/config/gemm_configs.py index e19a102ee4e..dbdcd07a668 100644 --- a/backends/xnnpack/partition/config/gemm_configs.py +++ b/backends/xnnpack/partition/config/gemm_configs.py @@ -9,6 +9,7 @@ from typing import cast, List, Optional, Tuple import torch +from executorch.backends.xnnpack.operators.quant_params import QuantParams from executorch.backends.xnnpack.partition.config.xnnpack_config import ( ConfigPrecisionType, XNNPartitionerConfig, @@ -317,7 +318,7 @@ def __init__(self, **kwargs): def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: """ - Currently we have no support for convolution 3d and transposed convolution + Currently we have no support for convolution 3d """ if not super().check_constraints(node, ep): return False @@ -327,11 +328,24 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: why(node, "Only support 1D + 2D Conv") return False # Only support 1D + 2D Conv - transposed = cast(bool, node.args[6]) - if transposed: - why(node, "Transposed Conv is not supported") - return False # Currently don't support transposed conv + kernel_node = get_input_node(node, 1) + weight_quant_params = QuantParams.from_weights(kernel_node, ep) + is_transpose = node.args[6] + groups = cast(int, node.args[8]) + + if ( + is_transpose + and weight_quant_params is not None + and weight_quant_params.per_channel + and (groups > 1 or weight_quant_params.axis != 1) + ): + why( + node, + "XNNPACK does not support per input channel quantization" + "for transpose convolutions with groups > 1", + ) + return False return True def supported_precision_types(self): diff --git a/backends/xnnpack/partition/configs.py b/backends/xnnpack/partition/configs.py index 2629695518b..ad4af24d3fc 100644 --- a/backends/xnnpack/partition/configs.py +++ b/backends/xnnpack/partition/configs.py @@ -73,6 +73,7 @@ torch.nn.BatchNorm2d, torch.nn.BatchNorm1d, torch.nn.Conv2d, + torch.nn.ConvTranspose2d, torch.nn.Linear, torch.nn.functional.linear, torch.nn.PReLU, # Without this, the PReLU weight becomes not a get_attr diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 3d4d2e68219..1a01ce9f01a 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -979,6 +979,54 @@ Error defineConv2dNode( return Error::Ok; } +/* +Define serialized conv_transpose2d node into the subgraph, using the remapped +ids to map the serialized ids, to the new ids generated when defining the tensor +value +*/ +Error defineConvTranspose2dNode( + xnn_subgraph_t subgraph_ptr, + const std::unordered_map& remapped_ids, + const NodePtr node, + const fb_xnnpack::XNNGraph* graph) noexcept { + MAYBE_UNUSED(graph); + auto graph_node = node->xnode_union_as_XNNConvTranspose2d(); + + std::pair min_max = getOutputMinMax(node); + xnn_status status = xnn_define_deconvolution_2d( + subgraph_ptr, + graph_node->padding_top(), + graph_node->padding_right(), + graph_node->padding_bottom(), + graph_node->padding_left(), + graph_node->adjustment_height(), + graph_node->adjustment_width(), + graph_node->kernel_height(), + graph_node->kernel_width(), + graph_node->subsampling_height(), + graph_node->subsampling_width(), + graph_node->dilation_height(), + graph_node->dilation_width(), + graph_node->groups(), + graph_node->group_input_channels(), + graph_node->group_output_channels(), + min_max.first, + min_max.second, + remapped_ids.at(graph_node->input1_id()), + remapped_ids.at(graph_node->filter_id()), + remapped_ids.at(graph_node->bias_id()), + remapped_ids.at(graph_node->output_id()), + graph_node->flags()); + ET_CHECK_OR_RETURN_ERROR( + status == xnn_status_success, + Internal, + "Failed to create deconvolution node %i with code: %s", + node->debug_handle(), + xnn_status_to_string(status)); + + return Error::Ok; +} + /* Define serialized maxpool2d node into the subgraph, using the remapped ids to map the serialized ids, to the new ids generated when defining the @@ -1840,6 +1888,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) { _DEFINE(StaticTranspose) _DEFINE(Clamp) _DEFINE(Conv2d) + _DEFINE(ConvTranspose2d) _DEFINE(Div) _DEFINE(StaticResizeBilinear2D) _DEFINE(StaticConstantPad) diff --git a/backends/xnnpack/serialization/runtime_schema.fbs b/backends/xnnpack/serialization/runtime_schema.fbs index 0c6ee86912f..8ba346d9bc0 100644 --- a/backends/xnnpack/serialization/runtime_schema.fbs +++ b/backends/xnnpack/serialization/runtime_schema.fbs @@ -137,6 +137,7 @@ union XNodeUnion { XNNScaledDotProductAttention, XNNBatchMatrixMultiply: _XNNNode2x1, XNNConcatenate5: _XNNCat, + XNNConvTranspose2d: _XNNNodeConv, } union XValueUnion { diff --git a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs index 45f12484635..81263825ff5 100644 --- a/backends/xnnpack/serialization/schema.fbs +++ b/backends/xnnpack/serialization/schema.fbs @@ -133,6 +133,7 @@ union XNodeUnion { XNNScaledDotProductAttention, XNNBatchMatrixMultiply: _XNNNode2x1, XNNConcatenate5: _XNNCat, + XNNConvTranspose2d: _XNNNodeConv, } union XValueUnion { diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py index ca0fc60bdc0..7c23a75507d 100644 --- a/backends/xnnpack/serialization/xnnpack_graph_schema.py +++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py @@ -103,6 +103,11 @@ class XNNConv2d(XNNNodeConv): pass +@dataclass +class XNNConvTranspose2d(XNNNodeConv): + pass + + @dataclass class XNNAdd(XNNNode2x1): pass @@ -336,6 +341,7 @@ class XNNScaledDotProductAttention: XNNStaticTranspose, XNNClamp, XNNConv2d, + XNNConvTranspose2d, XNNDiv, XNNStaticResizeBilinear2D, XNNStaticConstantPad, diff --git a/backends/xnnpack/test/ops/test_conv2d.py b/backends/xnnpack/test/ops/test_conv2d.py index d88f88724bd..533b9ab90cf 100644 --- a/backends/xnnpack/test/ops/test_conv2d.py +++ b/backends/xnnpack/test/ops/test_conv2d.py @@ -9,8 +9,19 @@ from typing import Optional import torch + +try: + import executorch.extension.pybindings.portable_lib # noqa[F401] + import executorch.kernels.quantized # noqa[F401] + + has_quantized_ops = True +except: + has_quantized_ops = False + from executorch.backends.xnnpack.test.test_xnnpack_utils import randomize_bn from executorch.backends.xnnpack.test.tester import Quantize, Tester + +from executorch.exir.dialects._ops import ops as exir_ops from torch.ao.quantization.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, ) @@ -33,6 +44,7 @@ def __init__( width=8, height=8, dtype=torch.float, + transpose=False, ): super().__init__() self.batches = batches @@ -41,7 +53,9 @@ def __init__( self.in_channels = in_channels self.dtype = dtype - self.conv = torch.nn.Conv2d( + op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d + self.transpose = transpose + self.conv = op( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, @@ -65,16 +79,18 @@ def get_inputs(self): class Conv2dSeq(torch.nn.Module): - def __init__(self): + def __init__(self, transpose=False): super().__init__() - self.first = torch.nn.Conv2d( + op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d + self.transpose = transpose + self.first = op( in_channels=1, out_channels=3, kernel_size=(3, 3), padding=1, bias=False, ) - self.second = torch.nn.Conv2d( + self.second = op( in_channels=3, out_channels=2, kernel_size=(3, 3), @@ -91,9 +107,11 @@ def get_inputs(self): class Conv2dBatchNorm(torch.nn.Module): - def __init__(self): + def __init__(self, transpose=False): super().__init__() - self.conv1 = torch.nn.Conv2d( + op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d + self.transpose = transpose + self.conv1 = op( 2, 2, (2, 2), @@ -103,7 +121,7 @@ def __init__(self): ) self.bn = randomize_bn(2) self.hardtanh = torch.nn.Hardtanh() - self.conv2 = torch.nn.Conv2d( + self.conv2 = op( 2, 2, (2, 2), @@ -126,9 +144,11 @@ def get_inputs(self): class Conv2dPermute(torch.nn.Module): - def __init__(self, permute_order): + def __init__(self, permute_order, transpose=False): super().__init__() - self.conv = torch.nn.Conv2d( + op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d + self.transpose = transpose + self.conv = op( 2, 2, (2, 2), @@ -154,6 +174,8 @@ def _test( quant_config: Optional[QuantizationConfig] = None, conv_count=1, dtype: torch.dtype = torch.float, + check_quantized=True, + delegated=True, ): # pyre-fixme[29]: `Union[torch._tensor.Tensor, # torch.nn.modules.module.Module]` is not a function. @@ -161,55 +183,87 @@ def _test( if quant_config is not None: tester = tester.quantize(Quantize(quantization_config=quant_config)) - tester.check(["torch.ops.quantized_decomposed"]) - - ( - tester.export() - .check_count({"torch.ops.aten.conv2d": conv_count}) - .to_edge_transform_and_lower() - .check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"]) - .check_not( - [ - "executorch_exir_dialects_edge__ops__native_batch_norm_legit_no_training_default" - ] - ) - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .serialize() - .run_method_and_compare_outputs(qtol=1) + if check_quantized: + tester.check(["torch.ops.quantized_decomposed"]) + + op = ( + "torch.ops.aten.conv2d" + if not m.transpose + else "torch.ops.aten.conv_transpose2d" ) + (tester.export().check_count({op: conv_count}).to_edge_transform_and_lower()) + + if delegated: + ( + tester.check_not( + ["executorch_exir_dialects_edge__ops_aten_convolution_default"] + ) + .check_not( + [ + "executorch_exir_dialects_edge__ops__native_batch_norm_legit_no_training_default" + ] + ) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .serialize() + .run_method_and_compare_outputs(qtol=1) + ) + else: + # need quantize ops when ops are not delegated to xnnpack + if has_quantized_ops: + ( + tester.to_executorch() + .serialize() + .run_method_and_compare_outputs(qtol=1) + ) + def test_fp16_conv2d(self) -> None: - for has_bias in (True, False): - self._test(Conv2d(bias=has_bias, dtype=torch.float16)) + for transpose in (True, False): + for has_bias in (True, False): + self._test( + Conv2d(bias=has_bias, dtype=torch.float16, transpose=transpose) + ) def test_fp32_conv2d(self) -> None: - for has_bias in (True, False): - self._test(Conv2d(bias=has_bias)) + for transpose in (True, False): + for has_bias in (True, False): + self._test(Conv2d(bias=has_bias, transpose=transpose)) def test_fp32_conv2d_permute(self) -> None: - for perm_order in list(itertools.permutations([0, 1, 2, 3])): - self._test(Conv2dPermute(perm_order)) + for transpose in (True, False): + for perm_order in list(itertools.permutations([0, 1, 2, 3])): + self._test(Conv2dPermute(perm_order, transpose=transpose)) def test_qs8_conv2d_test(self) -> None: - for has_bias in (True, False): - self._test( - Conv2d(bias=has_bias), quant_config=get_symmetric_quantization_config() - ) + for transpose in (True, False): + for has_bias in (True, False): + self._test( + Conv2d(bias=has_bias, transpose=transpose), + quant_config=get_symmetric_quantization_config(), + check_quantized=not transpose, # XNNPackQuantizer does not quantize this pattern yet + ) def test_qs8_conv2d_per_channel(self) -> None: - self._test( - Conv2d(), - quant_config=get_symmetric_quantization_config(is_per_channel=True), - ) + for transpose in (True, False): + self._test( + Conv2d(transpose=transpose), + quant_config=get_symmetric_quantization_config(is_per_channel=True), + check_quantized=not transpose, # XNNPackQuantizer does not quantize this pattern yet + ) def test_fp32_conv2d_seq(self) -> None: - self._test(Conv2dSeq(), conv_count=2) + for transpose in (True, False): + self._test(Conv2dSeq(transpose=transpose), conv_count=2) def test_qs8_conv2d_seq(self) -> None: - self._test( - Conv2dSeq(), conv_count=2, quant_config=get_symmetric_quantization_config() - ) + for transpose in (True, False): + self._test( + Conv2dSeq(transpose=transpose), + conv_count=2, + quant_config=get_symmetric_quantization_config(), + check_quantized=not transpose, # XNNPackQuantizer does not quantize this pattern yet + ) def test_fp32_conv2d_single_int_params(self): self._test( @@ -225,19 +279,29 @@ def test_fp32_conv2d_depthwise(self): # Depthwise Convolution Requirements: # - Groups must equal In Channels # - Out Channels must be a positive multiple of In Channels - self._test(Conv2d(groups=2, in_channels=2, out_channels=6)) + for transpose in (True, False): + + self._test( + Conv2d(groups=2, in_channels=2, out_channels=6, transpose=transpose) + ) def test_qs8_conv2d_depthwise(self): - self._test( - Conv2d(groups=2, in_channels=2, out_channels=6), - quant_config=get_symmetric_quantization_config(), - ) + for transpose in (True, False): + self._test( + Conv2d(groups=2, in_channels=2, out_channels=6, transpose=transpose), + quant_config=get_symmetric_quantization_config(), + check_quantized=not transpose, # XNNPackQuantizer does not quantize this pattern yet + ) def test_fp32_conv2d_bn(self): class Conv2dBatchNorm(torch.nn.Module): - def __init__(self, in_features: int, out_features: int, kernel_size): + def __init__( + self, in_features: int, out_features: int, kernel_size, transpose=False + ): super().__init__() - self.conv2d = torch.nn.Conv2d(in_features, out_features, kernel_size) + op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d + self.transpose = transpose + self.conv2d = op(in_features, out_features, kernel_size) self.bn = randomize_bn(out_features) self.in_features = in_features self.kernel_size = kernel_size @@ -257,7 +321,15 @@ def get_inputs(self): ), ) - self._test(Conv2dBatchNorm(in_features=2, out_features=2, kernel_size=(2, 2))) + for transpose in (True, False): + self._test( + Conv2dBatchNorm( + in_features=2, + out_features=2, + kernel_size=(2, 2), + transpose=transpose, + ) + ) def test_fp32_conv2d_bn_hardtanh_mean_sequence(self): """ @@ -267,9 +339,13 @@ def test_fp32_conv2d_bn_hardtanh_mean_sequence(self): """ class Conv2dBatchNormHardTanh(torch.nn.Module): - def __init__(self, in_channels: int, out_channels: int, kernel_size): + def __init__( + self, in_channels: int, out_channels: int, kernel_size, transpose=False + ): super().__init__() - self.conv = torch.nn.Conv2d( + op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d + self.transpose = transpose + self.conv = op( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, @@ -290,22 +366,32 @@ def forward(self, x): def get_inputs(self): return (torch.randn(2, self.in_channels, 8, 8),) - self._test( - Conv2dBatchNormHardTanh(in_channels=2, out_channels=1, kernel_size=(2, 2)) - ) + for transpose in (True, False): + self._test( + Conv2dBatchNormHardTanh( + in_channels=2, + out_channels=1, + kernel_size=(2, 2), + transpose=transpose, + ) + ) def test_qs8_conv2d_bn(self): - self._test( - Conv2dBatchNorm(), - quant_config=get_symmetric_quantization_config(), - conv_count=2, - ) + for transpose in (True, False): + self._test( + Conv2dBatchNorm(transpose=transpose), + quant_config=get_symmetric_quantization_config(), + conv_count=2, + check_quantized=not transpose, # XNNPackQuantizer does not quantize this pattern yet + ) def test_qs8_conv2d_relu(self): class ConvReLU(torch.nn.Module): - def __init__(self): + def __init__(self, transpose=False): super().__init__() - self.conv1 = torch.nn.Conv2d( + op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d + self.transpose = transpose + self.conv1 = op( 2, 2, (2, 2), @@ -323,10 +409,12 @@ def forward(self, x): def get_inputs(self): return (torch.randn(2, 2, 4, 4),) - self._test( - ConvReLU(), - quant_config=get_symmetric_quantization_config(), - ) + for transpose in (True, False): + self._test( + ConvReLU(transpose=transpose), + quant_config=get_symmetric_quantization_config(is_per_channel=True), + delegated=not transpose, + ) def test_qs8_conv2d_dw_relu(self): # Depthwise Convolution Requirements: @@ -343,9 +431,11 @@ def test_qs8_conv2d_dw_relu(self): batches = 1 class ModelConvReLU(torch.nn.Module): - def __init__(self): + def __init__(self, transpose=False): super().__init__() - self.conv1 = torch.nn.Conv2d( + op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d + self.transpose = transpose + self.conv1 = op( in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), @@ -365,23 +455,31 @@ def forward(self, x): def get_inputs(self): return (torch.randn(batches, in_channels, height, width) * 11,) - for per_channel_quant in (False, True): - model = ModelConvReLU() - self._test( - model, - quant_config=get_symmetric_quantization_config( - is_per_channel=per_channel_quant - ), - ) + for transpose in (True, False): + for per_channel_quant in (False, True): + if transpose and per_channel_quant: + continue + model = ModelConvReLU(transpose=transpose) + self._test( + model, + quant_config=get_symmetric_quantization_config( + is_per_channel=per_channel_quant + ), + # xnnpack only supports per output channel quantization for transposed convolutions + # XNNPackQuantizer quantizes per input channel currently + delegated=not transpose or not per_channel_quant, + ) def test_qs8_conv2d_relu_seq(self): class ConvReLUSeq(torch.nn.Module): - def __init__(self): + def __init__(self, transpose=False): super().__init__() + op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d + self.transpose = transpose self.model = torch.nn.Sequential( - torch.nn.Conv2d(1, 1, 1), + op(1, 1, 1), torch.nn.ReLU(), - torch.nn.Conv2d(1, 64, 1), + op(1, 64, 1), torch.nn.ReLU(), ) @@ -391,18 +489,21 @@ def forward(self, x): def get_inputs(self): return (torch.randn(1, 1, 1, 1),) - self._test( - ConvReLUSeq(), - quant_config=get_symmetric_quantization_config(), - conv_count=2, - ) + for transpose in (True, False): + self._test( + ConvReLUSeq(transpose=transpose), + quant_config=get_symmetric_quantization_config(), + conv_count=2, + ) def test_qs8_conv2d_relu_multi_users(self): class Conv2dReluMultiUsers(torch.nn.Module): - def __init__(self): + def __init__(self, transpose=False): super().__init__() - self.conv1 = torch.nn.Conv2d(1, 1, 1) - self.conv2 = torch.nn.Conv2d(1, 64, 1) + op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d + self.transpose = transpose + self.conv1 = op(1, 1, 1) + self.conv2 = op(1, 64, 1) self.relu = torch.nn.ReLU() def forward(self, x): @@ -414,8 +515,149 @@ def forward(self, x): def get_inputs(self): return (torch.randn(1, 1, 1, 1),) - self._test( - Conv2dReluMultiUsers(), - quant_config=get_symmetric_quantization_config(), - conv_count=2, - ) + for transpose in (True, False): + self._test( + Conv2dReluMultiUsers(transpose=transpose), + quant_config=get_symmetric_quantization_config(), + conv_count=2, + ) + + def test_qs8_conv_transpose_2d_quantize_per_channel(self): + class PerChannelConvTranspose2d(torch.nn.Module): + def __init__(self, input_channels, output_channels, groups, axis): + super().__init__() + self.input_channels = input_channels + self.output_channels = output_channels + self.axis = axis + self.groups = groups + self.transpose = True + self.weights = torch.nn.Parameter( + torch.randint( + low=-127, + high=127, + size=(input_channels, output_channels // groups, 4, 4), + ).type(dtype=torch.int8), + requires_grad=False, + ) + + axis_size = self.weights.shape[axis] + self.scale = torch.nn.Parameter(torch.ones(axis_size) * 0.12345) + self.zero_point = torch.nn.Parameter( + torch.zeros((axis_size,), dtype=torch.int64), requires_grad=False + ) + + def forward(self, x): + dequantize_weights = ( + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default( + self.weights, + self.scale, + self.zero_point, + self.axis, + -127, + 127, + torch.int8, + ) + ) + dequantize_input = ( + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default( + x, 0.12345, 0, -127, 127, torch.int8 + ) + ) + x = torch.nn.functional.conv_transpose2d( + dequantize_input, dequantize_weights, groups=self.groups + ) + + return exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default( + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default( + x, + 0.12345, + 0, + -127, + 127, + torch.int8, + ), + 0.12345, + 0, + -127, + 127, + torch.int8, + ) + + def get_inputs(self): + return ( + torch.randint( + low=-127, high=127, size=(3, self.input_channels, 4, 4) + ).type(dtype=torch.int8), + ) + + for groups in (1, 2): + for axis in (0, 1): + self._test( + PerChannelConvTranspose2d(3 * groups, 5 * groups, groups, axis), + quant_config=None, + conv_count=1, + delegated=axis == 1 + and groups + == 1, # xnnpack only support output channel axis quantization with groups == 1 + ) + + def test_qs8_conv_transpose_2d_dqd_f32_weights(self): + class TransposeConv2dDQDf32weights(torch.nn.Module): + def __init__(self, input_channels, output_channels, groups, axis): + super().__init__() + self.input_channels = input_channels + self.output_channels = output_channels + self.axis = axis + self.groups = groups + self.transpose = True + self.weights = torch.nn.Parameter( + torch.randn((input_channels, output_channels // groups, 4, 4)), + requires_grad=False, + ) + + axis_size = self.weights.shape[axis] + self.scale = torch.nn.Parameter(torch.ones(axis_size) * 0.12345) + self.zero_point = torch.nn.Parameter( + torch.zeros((axis_size,), dtype=torch.int64), requires_grad=False + ) + + def forward(self, x): + dequantize_input = ( + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default( + x, 0.12345, 0, -127, 127, torch.int8 + ) + ) + x = torch.nn.functional.conv_transpose2d( + dequantize_input, self.weights, groups=self.groups + ) + + return exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default( + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default( + x, + 0.12345, + 0, + -127, + 127, + torch.int8, + ), + 0.12345, + 0, + -127, + 127, + torch.int8, + ) + + def get_inputs(self): + return ( + torch.randint( + low=-127, high=127, size=(3, self.input_channels, 4, 4) + ).type(dtype=torch.int8), + ) + + for groups in (1, 2): + for axis in (0, 1): + self._test( + TransposeConv2dDQDf32weights(3 * groups, 5 * groups, groups, axis), + quant_config=None, + conv_count=1, + ) diff --git a/backends/xnnpack/test/passes/test_activation_fusion.py b/backends/xnnpack/test/passes/test_activation_fusion.py index a7964a3181c..5f340f61b2e 100644 --- a/backends/xnnpack/test/passes/test_activation_fusion.py +++ b/backends/xnnpack/test/passes/test_activation_fusion.py @@ -77,6 +77,20 @@ def test_activation_fusion_conv_relu(self): quantize=True, ) + def test_activation_fusion_conv_transpose_relu(self): + inputs = (torch.randn(1, 1, 8, 8),) + self._test_op_activation_case( + torch.nn.ConvTranspose2d(1, 1, (4, 4)), + exir_ops.edge.aten.convolution.default, + inputs, + ) + self._test_op_activation_case( + torch.nn.ConvTranspose2d(1, 1, (4, 4)), + exir_ops.edge.aten.convolution.default, + inputs, + quantize=True, + ) + def test_activation_fusion_linear_relu(self): inputs = (torch.randn(1, 1, 8, 8),) self._test_op_activation_case( @@ -153,6 +167,23 @@ def test_activation_fusion_conv_hardtanh(self): activation_name="executorch_exir_dialects_edge__ops_aten_hardtanh_default", ) + def test_activation_fusion_conv_transpose_hardtanh(self): + inputs = (torch.randn(1, 1, 8, 8),) + self._test_op_activation_case( + torch.nn.ConvTranspose2d(1, 1, (4, 4)), + exir_ops.edge.aten.convolution.default, + inputs, + activation=torch.nn.Hardtanh(min_val=-1.0, max_val=1.0), + activation_name="executorch_exir_dialects_edge__ops_aten_hardtanh_default", + ) + self._test_op_activation_case( + torch.nn.ConvTranspose2d(1, 1, (4, 4)), + exir_ops.edge.aten.convolution.default, + inputs, + activation=torch.nn.Hardtanh(min_val=-1.0, max_val=1.0), + activation_name="executorch_exir_dialects_edge__ops_aten_hardtanh_default", + ) + def test_activation_fusion_linear_hardtanh(self): inputs = (torch.randn(1, 1, 8, 8),) self._test_op_activation_case( diff --git a/backends/xnnpack/test/passes/test_batch_norm_fusion.py b/backends/xnnpack/test/passes/test_batch_norm_fusion.py index 98e9547c47a..59d0e0a2072 100644 --- a/backends/xnnpack/test/passes/test_batch_norm_fusion.py +++ b/backends/xnnpack/test/passes/test_batch_norm_fusion.py @@ -20,11 +20,17 @@ class TestBatchNormFusion(unittest.TestCase): class ModelConvBN(torch.nn.Module): def __init__( - self, in_features: int, out_features: int, kernel_size: Tuple[int, int] + self, + in_features: int, + out_features: int, + kernel_size: Tuple[int, int], + transpose: bool, ): super().__init__() - self.conv2d = torch.nn.Conv2d(in_features, out_features, kernel_size) + op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d + self.conv2d = op(in_features, out_features, kernel_size) self.bn = torch.nn.BatchNorm2d(out_features) + self.forward(torch.randn(2, 2, 4, 4) * 2 + 2) # update the BN stats def forward(self, x): y = self.conv2d(x) @@ -34,25 +40,33 @@ def forward(self, x): return self.bn(y) def test_fp32_batch_norm_fusion(self): - ( - Tester(self.ModelConvBN(2, 2, (2, 2)).eval(), (torch.randn(2, 2, 4, 4),)) - .export() - .to_edge() - .run_passes(self.PassStage) - .check_count({self.bn_name: 1}) - .run_method_and_compare_outputs() - ) + for transpose in [False, True]: + ( + Tester( + self.ModelConvBN(2, 2, (2, 2), transpose).eval(), + (torch.randn(2, 2, 4, 4),), + ) + .export() + .to_edge() + .run_passes(self.PassStage) + .check_count({self.bn_name: 1}) + .run_method_and_compare_outputs() + ) def test_q8_batch_norm_fusion(self): - ( - Tester(self.ModelConvBN(2, 2, (2, 2)).eval(), (torch.randn(2, 2, 4, 4),)) - .quantize() - .export() - .to_edge() - .run_passes(self.PassStage) - .check_count({self.bn_name: 1}) - .run_method_and_compare_outputs() - ) + for transpose in [False, True]: + ( + Tester( + self.ModelConvBN(2, 2, (2, 2), transpose).eval(), + (torch.randn(2, 2, 4, 4),), + ) + .quantize() + .export() + .to_edge() + .run_passes(self.PassStage) + .check_count({self.bn_name: 1}) + .run_method_and_compare_outputs() + ) def test_fp32_batch_norm_no_fusion_doesnt_partition(self): """ diff --git a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py index fe781972e34..c1438b29213 100644 --- a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py +++ b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py @@ -23,6 +23,9 @@ class TestChannelsLastTaggedReshapePass(unittest.TestCase): OpSequencesAddConv2d(0, 0).eval(): 0, OpSequencesAddConv2d(1, 1).eval(): 2, OpSequencesAddConv2d(2, 2).eval(): 2, + OpSequencesAddConv2d(0, 0, True).eval(): 0, + OpSequencesAddConv2d(1, 1, True).eval(): 2, + OpSequencesAddConv2d(2, 2, True).eval(): 2, } to_copy_name = "executorch_exir_dialects_edge__ops_aten__to_copy_default" quant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default" diff --git a/backends/xnnpack/test/test_xnnpack_utils_classes.py b/backends/xnnpack/test/test_xnnpack_utils_classes.py index 50a1914a56b..548eb3b309d 100644 --- a/backends/xnnpack/test/test_xnnpack_utils_classes.py +++ b/backends/xnnpack/test/test_xnnpack_utils_classes.py @@ -14,17 +14,19 @@ class OpSequencesAddConv2d(torch.nn.Module): followed by an add to separate the sequences """ - def __init__(self, num_sequences, ops_per_sequence): + def __init__(self, num_sequences, ops_per_sequence, transpose=False): super().__init__() self.num_ops = num_sequences * ops_per_sequence self.num_sequences = num_sequences self.op_sequence = torch.nn.ModuleList() + + op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d for _ in range(num_sequences): inner = torch.nn.ModuleList() for _ in range(ops_per_sequence): inner.append( - torch.nn.Conv2d( + op( in_channels=1, out_channels=1, kernel_size=(3, 3), diff --git a/devtools/size_analysis_tool/size_analysis_tool_test.py b/devtools/size_analysis_tool/size_analysis_tool_test.py index 96feae7e423..00e1c9567a4 100644 --- a/devtools/size_analysis_tool/size_analysis_tool_test.py +++ b/devtools/size_analysis_tool/size_analysis_tool_test.py @@ -80,20 +80,6 @@ def forward(self, x): "shape": [2, 4, 3, 3, 3], "num_bytes": 864, }, - # ConvTranspose2d Weight - 32: { - "dtype": "float32", - "element_size": 4, - "shape": [2, 4, 2, 2], - "num_bytes": 128, - }, - # ConvTranspose2d Bias - 4: { - "dtype": "float32", - "element_size": 4, - "shape": [4], - "num_bytes": 16, - }, # Conv3d Bias 2: { "dtype": "float32", @@ -111,5 +97,5 @@ def forward(self, x): for k, v in exepected_tensor_data[tensor["numel"]].items(): self.assertEqual(tensor[k], v) - # Two delegate blobs: sigmoid and conv2d + # Two delegate blobs: sigmoid and conv2d/conv_transpose2d self.assertEqual(len(size_information["delegate_blob_data"]), 2)