diff --git a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
index c1bc3a54f7c..179006bc1b6 100644
--- a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
+++ b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
 from enum import Enum
 from typing import Optional, Tuple
 
@@ -161,6 +163,9 @@ def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
         return node.target in self.memory_sensitive_ops_nhwc
 
     def requires_nchw_inputs(self, node: torch.fx.Node) -> bool:
+        if node.target == exir_ops.edge.aten.view_copy.default:
+            return True
+
         return node.target in self.memory_sensitive_ops_nchw
 
     def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
diff --git a/backends/xnnpack/operators/__init__.py b/backends/xnnpack/operators/__init__.py
index 02a46a6fc47..65b5a2327ae 100644
--- a/backends/xnnpack/operators/__init__.py
+++ b/backends/xnnpack/operators/__init__.py
@@ -56,4 +56,5 @@
     op_sub,
     op_tanh,
     op_to_copy,
+    op_view_copy,
 )
diff --git a/backends/xnnpack/operators/op_skip_ops.py b/backends/xnnpack/operators/op_skip_ops.py
index 19df74e77ac..04be2b274b2 100644
--- a/backends/xnnpack/operators/op_skip_ops.py
+++ b/backends/xnnpack/operators/op_skip_ops.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
 from typing import Dict
 
 import torch
@@ -59,16 +61,6 @@ class OpTCopyDefault(OpSkipOps):
     target = "aten.t_copy.default"
 
 
-@register_node_visitor
-class OpViewCopyDefault(OpSkipOps):
-    """
-    currently, do nothing if node is view_copy.default
-    need to handle this later on, currently view it as one of skip ops
-    """
-
-    target = "aten.view_copy.default"
-
-
 @register_node_visitor
 class OpSymSizeInt(OpSkipOps):
     """
diff --git a/backends/xnnpack/operators/op_view_copy.py b/backends/xnnpack/operators/op_view_copy.py
new file mode 100644
index 00000000000..5a8bf342eab
--- /dev/null
+++ b/backends/xnnpack/operators/op_view_copy.py
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import Dict
+
+import torch
+from executorch.backends.xnnpack.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
+    XNNGraph,
+    XNNStaticReshape,
+    XNode,
+)
+from executorch.backends.xnnpack.utils.utils import (
+    check_or_raise,
+    get_input_node,
+    PERM_NCHW_TO_NHWC,
+)
+
+
+@register_node_visitor
+class ViewCopyVisitor(NodeVisitor):
+    target = "aten.view_copy.default"
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        xnn_graph: XNNGraph,
+        vals_to_ids: Dict[torch.fx.Node, int],
+        debug_handle: int,
+    ) -> None:
+        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
+
+        input_node = get_input_node(node, 0)
+
+        # input
+        input_id = vals_to_ids[input_node]
+
+        # output
+        output_id = vals_to_ids[node]
+
+        # input shape
+        check_or_raise(
+            "val" in input_node.meta,
+            "Missing val in tensor metadata for input when serializing XNNStaticReshape",
+        )
+
+        # output shape
+        check_or_raise(
+            "val" in node.meta,
+            "Missing val in tensor metadata for output when serializing XNNStaticReshape",
+        )
+
+        new_shape = node.args[1]
+        check_or_raise(
+            all(isinstance(d, int) for d in new_shape),
+            "Symbolic reshape parameter is not supported in XNNStaticReshape",
+        )
+
+        # PyTorch uses -1 for inferred dims, whereas XNNPACK expects 0.
+        new_shape = tuple(d if d != -1 else 0 for d in new_shape)
+
+        # Handle dim order - if this op is tagged to run in NHWC, permute the
+        # requested view shape (given in NCHW order) correspondingly.
+        if "XNN_NHWC_NODE" in node.meta:
+            check_or_raise(len(new_shape) == 4, "Invalid NCHW shape")
+            new_shape = [new_shape[PERM_NCHW_TO_NHWC[n]] for n in range(4)]
+
+        num_dynamic_dims = sum(1 for d in new_shape if d == 0)
+
+        check_or_raise(
+            num_dynamic_dims <= 1,
+            "XNNPACK reshape only supports 1 dynamic dimension.",
+        )
+
+        ser_node = XNode(
+            xnode_union=XNNStaticReshape(
+                num_dims=len(new_shape),
+                new_shape=new_shape,
+                input_id=input_id,
+                output_id=output_id,
+                flags=0,
+            ),
+            debug_handle=debug_handle,
+        )
+        xnn_graph.xnodes.append(ser_node)
diff --git a/backends/xnnpack/partition/config/__init__.py b/backends/xnnpack/partition/config/__init__.py
index 5427b3a7838..f4026b31fb1 100644
--- a/backends/xnnpack/partition/config/__init__.py
+++ b/backends/xnnpack/partition/config/__init__.py
@@ -55,6 +55,7 @@
     TanhConfig,
     ToDimOrderCopyConfig,
     UpsampleBilinear2dConfig,
+    ViewCopyConfig,
 )
 from executorch.backends.xnnpack.partition.config.node_configs import (
     BatchNormConfig,
@@ -115,6 +116,7 @@
     SquareRootConfig,
     SubConfig,
     UpsampleBilinear2dConfig,
+    ViewCopyConfig,
     # Quant/Dequant Op Configs
     QuantizedPerTensorConfig,
     DeQuantizedPerTensorConfig,
diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py
index 434fce1d73a..14c114952e9 100644
--- a/backends/xnnpack/partition/config/generic_node_configs.py
+++ b/backends/xnnpack/partition/config/generic_node_configs.py
@@ -380,6 +380,43 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
 
 
+class ViewCopyConfig(GenericNodePartitionerConfig):
+    target_name = "view_copy.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        """
+        XNNPACK's static_reshape only supports 1 dynamic dimension.
+        """
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        new_shape = node.args[1]
+
+        # Check for symbolic dims. They aren't lowerable to XNNPACK currently.
+        symbolic_dim_indices = [
+            i for i, d in enumerate(new_shape) if not isinstance(d, int)
+        ]
+        if not all(isinstance(n, int) for n in new_shape):
+            why(
+                node,
+                reason=f"Symbolic reshape is not supported. Output shape is {new_shape} and dims at {symbolic_dim_indices} are symbolic.",
+            )
+            return False
+
+        dynamic_dim_indices = [i for i, d in enumerate(new_shape) if d == -1]
+        if len(dynamic_dim_indices) > 1:
+            why(
+                node,
+                reason=f"Only a single inferred dimension is supported. Output shape is {new_shape} and dims {dynamic_dim_indices} are inferred.",
+            )
+            return False
+
+        return True
+
+
 class FloorConfig(GenericNodePartitionerConfig):
     target_name = "floor.default"
 
diff --git a/backends/xnnpack/test/ops/test_view_copy.py b/backends/xnnpack/test/ops/test_view_copy.py
new file mode 100644
index 00000000000..5a22e7c28a9
--- /dev/null
+++ b/backends/xnnpack/test/ops/test_view_copy.py
@@ -0,0 +1,290 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack.test.tester import Export, Tester
+from executorch.exir.dialects._ops import ops as exir_ops
+from torch.export import Dim
+
+
+class TestViewCopy(unittest.TestCase):
+    class View(torch.nn.Module):
+        def __init__(self, new_shape):
+            super().__init__()
+            self.new_shape = new_shape
+
+        def forward(self, x):
+            z = x.view(self.new_shape)
+            return z
+
+    def test_fp16_view_copy(self):
+        inputs = (torch.randn(4, 4).to(torch.float16),)
+        (
+            Tester(self.View((2, 8)), inputs)
+            .export()
+            .check_node_count({torch.ops.aten.view.default: 1})
+            .to_edge_transform_and_lower()
+            .check_node_count(
+                {
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp32_view_copy(self):
+        inputs = (torch.randn(4, 4),)
+        (
+            Tester(self.View((2, 8)), inputs)
+            .export()
+            .check_node_count({torch.ops.aten.view.default: 1})
+            .to_edge_transform_and_lower()
+            .check_node_count(
+                {
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp32_view_copy_inferred_dim(self):
+        inputs = (torch.randn(4, 4),)
+        (
+            Tester(self.View((-1, 8)), inputs)
+            .export()
+            .check_node_count({torch.ops.aten.view.default: 1})
+            .to_edge_transform_and_lower()
+            .check_node_count(
+                {
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp32_view_copy_dynamic_shape(self):
+        inputs = (torch.randn(4, 4, 6),)
+        for dynamic_dim_index in range(len(inputs[0].shape)):
+            dynamic_shapes = {
+                "x": {dynamic_dim_index: Dim("x", min=1, max=10) * 2},
+            }
+
+            # Test at the min and max bounds of the dynamic dimension.
+            min_shape = list(inputs[0].shape)
+            min_shape[dynamic_dim_index] = 2
+            max_shape = list(inputs[0].shape)
+            max_shape[dynamic_dim_index] = 20
+            test_inputs = [
+                (torch.randn(min_shape),),
+                (torch.randn(max_shape),),
+            ]
+
+            # Non-dynamic dimensions are halved in the view.
+            view_size = [n // 2 for n in inputs[0].shape]
+            view_size[dynamic_dim_index] = -1
+
+            tester = (
+                Tester(self.View(view_size), inputs)
+                .export(Export(dynamic_shapes=dynamic_shapes))
+                .check_node_count({torch.ops.aten.view.default: 1})
+                .to_edge_transform_and_lower()
+                .check_node_count(
+                    {
+                        torch.ops.higher_order.executorch_call_delegate: 1,
+                        exir_ops.edge.aten.view_copy.default: 0,
+                    }
+                )
+                .to_executorch()
+                .serialize()
+                .run_method_and_compare_outputs()
+            )
+
+            for test_input in test_inputs:
+                tester.run_method_and_compare_outputs(inputs=test_input)
+
+    def test_fp32_view_copy_unsupported_dynamism(self):
+        class SymbolicView(torch.nn.Module):
+            def forward(self, x):
+                return x.view(x.shape[0] // 2, x.shape[1] * 2)
+
+        inputs = (torch.randn(4, 4),)
+        dynamic_shapes = {
+            "x": {1: Dim("x", min=1, max=10) * 2},
+        }
+        (
+            Tester(SymbolicView(), inputs)
+            .export(Export(dynamic_shapes=dynamic_shapes))
+            .check_node_count({torch.ops.aten.view.default: 1})
+            .to_edge_transform_and_lower()
+            .check_node_count(
+                {  # Expect no delegation, as the view's target shape contains a symbolic dimension.
+                    torch.ops.higher_order.executorch_call_delegate: 0,
+                    exir_ops.edge.aten.view_copy.default: 1,
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp32_view_copy_static_symbolic_arg(self):
+        class SymbolicView(torch.nn.Module):
+            def forward(self, x):
+                return x.view(x.shape[0] // 2, x.shape[1] * 2)
+
+        inputs = (torch.randn(4, 4),)
+        (
+            Tester(SymbolicView(), inputs)
+            .export()
+            .check_node_count({torch.ops.aten.view.default: 1})
+            .to_edge_transform_and_lower()
+            .check_node_count(
+                {
+                    # Expect delegation, as the symbolic shape expressions resolve to
+                    # static values in the absence of dynamic shapes.
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp32_view_copy_increase_rank(self):
+        inputs = (torch.randn(4, 4),)
+        (
+            Tester(self.View((1, 2, 4, 2)), inputs)
+            .export()
+            .check_node_count({torch.ops.aten.view.default: 1})
+            .to_edge_transform_and_lower()
+            .check_node_count(
+                {
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp32_view_copy_increase_rank_dynamic(self):
+        test_inputs = (
+            (torch.randn(2, 4),),
+            (torch.randn(10, 4),),
+        )
+        dynamic_shapes = {
+            "x": {0: Dim("x", min=1, max=10) * 2},
+        }
+        inputs = (torch.randn(4, 4),)
+        tester = (
+            Tester(self.View((1, 2, 4, -1)), inputs)
+            .export(Export(dynamic_shapes=dynamic_shapes))
+            .check_node_count({torch.ops.aten.view.default: 1})
+            .to_edge_transform_and_lower()
+            .check_node_count(
+                {
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+        for test_input in test_inputs:
+            tester.run_method_and_compare_outputs(inputs=test_input)
+
+    def test_fp32_view_copy_decrease_rank(self):
+        inputs = (torch.randn(4, 4),)
+        (
+            Tester(self.View((-1)), inputs)
+            .export()
+            .check_node_count({torch.ops.aten.view.default: 1})
+            .to_edge_transform_and_lower()
+            .check_node_count(
+                {
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp32_view_copy_decrease_rank_dynamic(self):
+        test_inputs = (
+            (torch.randn(2, 2, 4),),
+            (torch.randn(2, 10, 4),),
+        )
+        dynamic_shapes = {
+            "x": {1: Dim("x", min=1, max=10) * 2},
+        }
+        inputs = (torch.randn(2, 4, 4),)
+        tester = (
+            Tester(self.View((-1)), inputs)
+            .export(Export(dynamic_shapes=dynamic_shapes))
+            .check_node_count({torch.ops.aten.view.default: 1})
+            .to_edge_transform_and_lower()
+            .check_node_count(
+                {
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+        for test_input in test_inputs:
+            tester.run_method_and_compare_outputs(inputs=test_input)
+
+    def test_fp32_view_copy_nhwc(self):
+        class ViewNHWC(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(3, 3, 3)
+                self.conv2 = torch.nn.Conv2d(3, 3, 3)
+
+            def forward(self, x):
+                y = self.conv1(x)
+                y = y.view(1, 3, 3, -1)
+                y = self.conv2(y)
+                return y.view(1, 3, 2, -1)
+
+        inputs = (torch.randn(1, 3, 8, 8),)
+        (
+            Tester(ViewNHWC(), inputs)
+            .export()
+            .check_node_count({torch.ops.aten.view.default: 2})
+            .to_edge_transform_and_lower()
+            .check_node_count(
+                {
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
diff --git a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py
index a73a0eb0ad1..d823af9735e 100644
--- a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py
+++ b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
 import unittest
 
 import torch
@@ -23,6 +25,7 @@
     is_quant,
     is_tagged_as_implicit_q_dq,
 )
+from executorch.exir.dialects._ops import ops as exir_ops
 
 
 class TestChannelsLastTaggedReshapePass(unittest.TestCase):
@@ -480,3 +483,153 @@ def test_q_dq_nodes_around_copy_are_tagged(self):
 
         # Compare outputs
         tester.run_method_and_compare_outputs()
+
+    def test_fp32_channels_last_tagged_reshape_pass_nhwc_view(self):
+        # Views are always run in NCHW for now.
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(3, 3, 3)
+                self.conv2 = torch.nn.Conv2d(3, 3, 3)
+
+            def forward(self, x):
+                y = self.conv1(x)
+                y = y.view((1, 3, 3, -1))
+                return self.conv2(y)
+
+        inputs = (torch.randn(1, 3, 8, 8),)
+        (
+            Tester(Model(), inputs)
+            .export()
+            .to_edge()
+            .check_node_count(
+                {
+                    exir_ops.edge.aten.convolution.default: 2,
+                    exir_ops.edge.aten.view_copy.default: 1,
+                }
+            )
+            .run_passes(self.PassStage)
+            .run_method_and_compare_outputs()
+            .check_node_count(
+                {
+                    exir_ops.edge.aten.convolution.default: 2,
+                    exir_ops.edge.aten.view_copy.default: 1,
+                    # 4 dim order conversions - a pair at the start and end and a pair
+                    # around the view.
+                    exir_ops.edge.aten._to_copy.default: 4,
+                }
+            )
+        )
+
+    def test_fp32_channels_last_tagged_reshape_pass_nchw_view_channel_modified(self):
+        # View cannot run in NHWC because channel and/or batch are modified.
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(3, 3, 3)
+                self.conv2 = torch.nn.Conv2d(6, 3, 3)
+
+            def forward(self, x):
+                y = self.conv1(x)
+                y = y.view((1, 6, 6, -1))
+                return self.conv2(y)
+
+        inputs = (torch.randn(1, 3, 8, 8),)
+        (
+            Tester(Model(), inputs)
+            .export()
+            .to_edge()
+            .check_node_count(
+                {
+                    exir_ops.edge.aten.convolution.default: 2,
+                    exir_ops.edge.aten.view_copy.default: 1,
+                }
+            )
+            .run_passes(self.PassStage)
+            .run_method_and_compare_outputs()
+            .check_node_count(
+                {
+                    exir_ops.edge.aten.convolution.default: 2,
+                    exir_ops.edge.aten.view_copy.default: 1,
+                    exir_ops.edge.aten._to_copy.default: 4,
+                }
+            )
+        )
+
+    def test_fp32_channels_last_tagged_reshape_pass_nchw_view_batch_modified(self):
+        # View cannot run in NHWC because channel and/or batch are modified.
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(3, 3, 3)
+                self.conv2 = torch.nn.Conv2d(3, 3, 3)
+
+            def forward(self, x):
+                y = self.conv1(x)
+                y = y.view((2, 3, 6, -1))
+                return self.conv2(y)
+
+        inputs = (torch.randn(1, 3, 8, 8),)
+        (
+            Tester(Model(), inputs)
+            .export()
+            .to_edge()
+            .check_node_count(
+                {
+                    exir_ops.edge.aten.convolution.default: 2,
+                    exir_ops.edge.aten.view_copy.default: 1,
+                }
+            )
+            .run_passes(self.PassStage)
+            .run_method_and_compare_outputs()
+            .check_node_count(
+                {
+                    exir_ops.edge.aten.convolution.default: 2,
+                    exir_ops.edge.aten.view_copy.default: 1,
+                    exir_ops.edge.aten._to_copy.default: 4,
+                }
+            )
+        )
+
+    def test_fp32_channels_last_tagged_reshape_pass_flatten_view(self):
+        # View cannot run in NHWC because tensor rank changes.
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(3, 3, 3)
+                self.linear1 = torch.nn.Linear(36 * 3, 1)
+
+            def forward(self, x):
+                y = self.conv1(x)
+                y = y.view((x.shape[0], -1))
+                return self.linear1(y)
+
+        inputs = (torch.randn(1, 3, 8, 8),)
+        tester = (
+            Tester(Model(), inputs)
+            .export()
+            .to_edge()
+            .check_node_count(
+                {
+                    exir_ops.edge.aten.convolution.default: 1,
+                    exir_ops.edge.aten.view_copy.default: 1,
+                }
+            )
+            .run_passes(self.PassStage)
+            .run_method_and_compare_outputs()
+            .check_node_count(
+                {
+                    exir_ops.edge.aten.convolution.default: 1,
+                    exir_ops.edge.aten.view_copy.default: 1,
+                    exir_ops.edge.aten._to_copy.default: 2,
+                }
+            )
+        )
+
+        # Verify view is not tagged.
+        graph = tester.get_artifact().exported_program().module().graph
+        view_nodes = [
+            n for n in graph.nodes if n.target == exir_ops.edge.aten.view_copy.default
+        ]
+        self.assertEqual(1, len(view_nodes))
+        self.assertTrue(ChannelsLastTaggedReshapePass(None).is_nchw_node(view_nodes[0]))
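
For reference, a minimal sketch of how the new lowering path can be exercised end to end outside of the test harness. The module, shapes, and output filename are illustrative and not part of this change; the import paths assume the standard ExecuTorch source layout.

# Illustrative only: export a module containing aten.view and lower it through the
# XNNPACK partitioner so the resulting view_copy is delegated as XNNStaticReshape.
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class ViewModel(torch.nn.Module):
    def forward(self, x):
        # A single -1 marks the one inferred dimension XNNStaticReshape allows.
        return x.view(2, -1)


example_inputs = (torch.randn(4, 4),)
exported = torch.export.export(ViewModel().eval(), example_inputs)
executorch_program = to_edge_transform_and_lower(
    exported,
    partitioner=[XnnpackPartitioner()],
).to_executorch()

# "view_model_xnnpack.pte" is a hypothetical output path.
with open("view_model_xnnpack.pte", "wb") as f:
    f.write(executorch_program.buffer)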