From 7ceda5600a429f1195ec6bc076bc2fbbec67541b Mon Sep 17 00:00:00 2001
From: Gregory Comer
Date: Wed, 3 Dec 2025 22:34:24 -0800
Subject: [PATCH] Add view_copy/static_reshape support to XNNPACK delegate
 (#7959)

Summary:
Add support for delegating view_copy in the XNNPACK delegate via the XNN
static_reshape operator. This includes support for up to one dynamic
dimension. It also includes conditional support for NHWC, so long as batch
and channel are fixed (if batch or channel is modified, the reshape has to
be done in the native dim order).

Test Plan:
I've added test coverage for view_copy, including dynamic shape support and
partitioner constraints, to backends/xnnpack/test/ops/test_view_copy.py.
I've also added tests to cover the new view logic in
channels_last_tagged_reshape_pass. A minimal usage sketch of the new
lowering path is appended after the diff.

Differential Revision: D68691788

Pulled By: GregoryComer
---
 .../channels_last_tagged_reshape_pass.py      |   5 +
 backends/xnnpack/operators/__init__.py        |   1 +
 backends/xnnpack/operators/op_skip_ops.py     |  12 +-
 backends/xnnpack/operators/op_view_copy.py    |  96 ++++++
 backends/xnnpack/partition/config/__init__.py |   2 +
 .../partition/config/generic_node_configs.py  |  37 +++
 backends/xnnpack/test/ops/test_view_copy.py   | 290 ++++++++++++++++++
 .../test_channels_last_tagged_reshape.py      | 153 +++++++++
 8 files changed, 586 insertions(+), 10 deletions(-)
 create mode 100644 backends/xnnpack/operators/op_view_copy.py
 create mode 100644 backends/xnnpack/test/ops/test_view_copy.py

diff --git a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
index c1bc3a54f7c..179006bc1b6 100644
--- a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
+++ b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
 from enum import Enum
 from typing import Optional, Tuple
 
@@ -161,6 +163,9 @@ def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
         return node.target in self.memory_sensitive_ops_nhwc
 
     def requires_nchw_inputs(self, node: torch.fx.Node) -> bool:
+        if node.target == exir_ops.edge.aten.view_copy.default:
+            return True
+
         return node.target in self.memory_sensitive_ops_nchw
 
     def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
diff --git a/backends/xnnpack/operators/__init__.py b/backends/xnnpack/operators/__init__.py
index 02a46a6fc47..65b5a2327ae 100644
--- a/backends/xnnpack/operators/__init__.py
+++ b/backends/xnnpack/operators/__init__.py
@@ -56,4 +56,5 @@
     op_sub,
     op_tanh,
     op_to_copy,
+    op_view_copy,
 )
diff --git a/backends/xnnpack/operators/op_skip_ops.py b/backends/xnnpack/operators/op_skip_ops.py
index 19df74e77ac..04be2b274b2 100644
--- a/backends/xnnpack/operators/op_skip_ops.py
+++ b/backends/xnnpack/operators/op_skip_ops.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
 from typing import Dict
 
 import torch
@@ -59,16 +61,6 @@ class OpTCopyDefault(OpSkipOps):
     target = "aten.t_copy.default"
 
 
-@register_node_visitor
-class OpViewCopyDefault(OpSkipOps):
-    """
-    currently, do nothing if node is view_copy.default
-    need to handle this later on, currently view it as one of skip ops
-    """
-
-    target = "aten.view_copy.default"
-
-
 @register_node_visitor
 class OpSymSizeInt(OpSkipOps):
     """
diff --git a/backends/xnnpack/operators/op_view_copy.py b/backends/xnnpack/operators/op_view_copy.py
new file mode 100644
index 00000000000..5a8bf342eab
--- /dev/null
+++ b/backends/xnnpack/operators/op_view_copy.py
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import Dict
+
+import torch
+from executorch.backends.xnnpack.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
+    XNNGraph,
+    XNNStaticReshape,
+    XNode,
+)
+from executorch.backends.xnnpack.utils.utils import (
+    check_or_raise,
+    get_input_node,
+    PERM_NCHW_TO_NHWC,
+)
+
+
+@register_node_visitor
+class ViewCopyVisitor(NodeVisitor):
+    target = "aten.view_copy.default"
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        xnn_graph: XNNGraph,
+        vals_to_ids: Dict[torch.fx.Node, int],
+        debug_handle: int,
+    ) -> None:
+        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
+
+        input_node = get_input_node(node, 0)
+
+        # input
+        input_id = vals_to_ids[input_node]
+
+        # output
+        output_id = vals_to_ids[node]
+
+        # input shape
+        check_or_raise(
+            "val" in input_node.meta,
+            "Missing val in tensor metadata for input when serializing XNNStaticReshape",
+        )
+
+        # output shape
+        check_or_raise(
+            "val" in node.meta,
+            "Missing val in tensor metadata for output when serializing XNNStaticReshape",
+        )
+
+        new_shape = node.args[1]
+        check_or_raise(
+            all(isinstance(d, int) for d in new_shape),
+            "Symbolic reshape parameter is not supported in XNNStaticReshape",
+        )
+
+        # PyTorch uses -1 for inferred dims, whereas XNNPACK expects 0.
+        new_shape = tuple(d if d != -1 else 0 for d in new_shape)
+
+        # Handle NHWC dim order - if this node is tagged as NHWC, permute the view
+        # shape from NCHW to NHWC correspondingly.
+        if "XNN_NHWC_NODE" in node.meta:
+            check_or_raise(len(new_shape) == 4, "Invalid NCHW shape")
+            new_shape = [new_shape[PERM_NCHW_TO_NHWC[n]] for n in range(4)]
+
+        num_dynamic_dims = sum(1 for d in new_shape if d == 0)
+
+        check_or_raise(
+            num_dynamic_dims <= 1,
+            "XNNPACK reshape only supports 1 dynamic dimension.",
+        )
+
+        ser_node = XNode(
+            xnode_union=XNNStaticReshape(
+                num_dims=len(new_shape),
+                new_shape=new_shape,
+                input_id=input_id,
+                output_id=output_id,
+                flags=0,
+            ),
+            debug_handle=debug_handle,
+        )
+        xnn_graph.xnodes.append(ser_node)
diff --git a/backends/xnnpack/partition/config/__init__.py b/backends/xnnpack/partition/config/__init__.py
index 5427b3a7838..f4026b31fb1 100644
--- a/backends/xnnpack/partition/config/__init__.py
+++ b/backends/xnnpack/partition/config/__init__.py
@@ -55,6 +55,7 @@
     TanhConfig,
     ToDimOrderCopyConfig,
     UpsampleBilinear2dConfig,
+    ViewCopyConfig,
 )
 from executorch.backends.xnnpack.partition.config.node_configs import (
     BatchNormConfig,
@@ -115,6 +116,7 @@
     SquareRootConfig,
     SubConfig,
     UpsampleBilinear2dConfig,
+    ViewCopyConfig,
     # Quant/Dequant Op Configs
     QuantizedPerTensorConfig,
     DeQuantizedPerTensorConfig,
diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py
index 434fce1d73a..14c114952e9 100644
--- a/backends/xnnpack/partition/config/generic_node_configs.py
+++ b/backends/xnnpack/partition/config/generic_node_configs.py
@@ -380,6 +380,43 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
 
 
+class ViewCopyConfig(GenericNodePartitionerConfig):
+    target_name = "view_copy.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        """
+        XNNPACK's static_reshape only supports 1 dynamic dimension.
+        """
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        new_shape = node.args[1]
+
+        # Check for symbolic dims. They aren't lowerable to XNNPACK currently.
+        symbolic_dim_indices = [
+            i for i, d in enumerate(new_shape) if not isinstance(d, int)
+        ]
+        if not all(isinstance(n, int) for n in new_shape):
+            why(
+                node,
+                reason=f"Symbolic reshape is not supported. Output shape is {new_shape} and dims at {symbolic_dim_indices} are symbolic.",
+            )
+            return False
+
+        dynamic_dim_indices = [i for i, d in enumerate(new_shape) if d == -1]
+        if len(dynamic_dim_indices) > 1:
+            why(
+                node,
+                reason=f"Only a single inferred dimension is supported. Output shape is {new_shape} and dims {dynamic_dim_indices} are inferred.",
+            )
+            return False
+
+        return True
+
+
 class FloorConfig(GenericNodePartitionerConfig):
     target_name = "floor.default"
 
diff --git a/backends/xnnpack/test/ops/test_view_copy.py b/backends/xnnpack/test/ops/test_view_copy.py
new file mode 100644
index 00000000000..5a22e7c28a9
--- /dev/null
+++ b/backends/xnnpack/test/ops/test_view_copy.py
@@ -0,0 +1,290 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack.test.tester import Export, Tester
+from executorch.exir.dialects._ops import ops as exir_ops
+from torch.export import Dim
+
+
+class TestViewCopy(unittest.TestCase):
+    class View(torch.nn.Module):
+        def __init__(self, new_shape):
+            super().__init__()
+            self.new_shape = new_shape
+
+        def forward(self, x):
+            z = x.view(self.new_shape)
+            return z
+
+    def test_fp16_view_copy(self):
+        inputs = (torch.randn(4, 4).to(torch.float16),)
+        (
+            Tester(self.View((2, 8)), inputs)
+            .export()
+            .check_node_count({torch.ops.aten.view.default: 1})
+            .to_edge_transform_and_lower()
+            .check_node_count(
+                {
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp32_view_copy(self):
+        inputs = (torch.randn(4, 4),)
+        (
+            Tester(self.View((2, 8)), inputs)
+            .export()
+            .check_node_count({torch.ops.aten.view.default: 1})
+            .to_edge_transform_and_lower()
+            .check_node_count(
+                {
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp32_view_copy_inferred_dim(self):
+        inputs = (torch.randn(4, 4),)
+        (
+            Tester(self.View((-1, 8)), inputs)
+            .export()
+            .check_node_count({torch.ops.aten.view.default: 1})
+            .to_edge_transform_and_lower()
+            .check_node_count(
+                {
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp32_view_copy_dynamic_shape(self):
+        inputs = (torch.randn(4, 4, 6),)
+        for dynamic_dim_index in range(len(inputs[0].shape)):
+            dynamic_shapes = {
+                "x": {dynamic_dim_index: Dim("x", min=1, max=10) * 2},
+            }
+
+            # Test at the min and max bounds of the dynamic dimension.
+            min_shape, max_shape = list(inputs[0].shape), list(inputs[0].shape)
+            min_shape[dynamic_dim_index] = 2
+            max_shape[dynamic_dim_index] = 20
+            test_inputs = [
+                (torch.randn(min_shape),), (torch.randn(max_shape),),
+            ]
+
+            # Non-dynamic dimensions are halved in the view.
+            view_size = [n // 2 for n in inputs[0].shape]
+            view_size[dynamic_dim_index] = -1
+
+            tester = (
+                Tester(self.View(view_size), inputs)
+                .export(Export(dynamic_shapes=dynamic_shapes))
+                .check_node_count({torch.ops.aten.view.default: 1})
+                .to_edge_transform_and_lower()
+                .check_node_count(
+                    {
+                        torch.ops.higher_order.executorch_call_delegate: 1,
+                        exir_ops.edge.aten.view_copy.default: 0,
+                    }
+                )
+                .to_executorch()
+                .serialize()
+                .run_method_and_compare_outputs()
+            )
+
+            for test_input in test_inputs:
+                tester.run_method_and_compare_outputs(inputs=test_input)
+
+    def test_fp32_view_copy_unsupported_dynamism(self):
+        class SymbolicView(torch.nn.Module):
+            def forward(self, x):
+                return x.view(x.shape[0] // 2, x.shape[1] * 2)
+
+        inputs = (torch.randn(4, 4),)
+        dynamic_shapes = {
+            "x": {1: Dim("x", min=1, max=10) * 2},
+        }
+        (
+            Tester(SymbolicView(), inputs)
+            .export(Export(dynamic_shapes=dynamic_shapes))
+            .check_node_count({torch.ops.aten.view.default: 1})
+            .to_edge_transform_and_lower()
+            .check_node_count(
+                {
+                    # Expect no delegation as the view has two dynamic dimensions.
+                    torch.ops.higher_order.executorch_call_delegate: 0,
+                    exir_ops.edge.aten.view_copy.default: 1,
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp32_view_copy_static_symbolic_arg(self):
+        class SymbolicView(torch.nn.Module):
+            def forward(self, x):
+                return x.view(x.shape[0] // 2, x.shape[1] * 2)
+
+        inputs = (torch.randn(4, 4),)
+        (
+            Tester(SymbolicView(), inputs)
+            .export()
+            .check_node_count({torch.ops.aten.view.default: 1})
+            .to_edge_transform_and_lower()
+            .check_node_count(
+                {
+                    # Expect delegation, as the symbolic shape expressions will
+                    # be resolved to static values in the absence of dynamic shapes.
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp32_view_copy_increase_rank(self):
+        inputs = (torch.randn(4, 4),)
+        (
+            Tester(self.View((1, 2, 4, 2)), inputs)
+            .export()
+            .check_node_count({torch.ops.aten.view.default: 1})
+            .to_edge_transform_and_lower()
+            .check_node_count(
+                {
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp32_view_copy_increase_rank_dynamic(self):
+        test_inputs = (
+            (torch.randn(2, 4),),
+            (torch.randn(10, 4),),
+        )
+        dynamic_shapes = {
+            "x": {0: Dim("x", min=1, max=10) * 2},
+        }
+        inputs = (torch.randn(4, 4),)
+        tester = (
+            Tester(self.View((1, 2, 4, -1)), inputs)
+            .export(Export(dynamic_shapes=dynamic_shapes))
+            .check_node_count({torch.ops.aten.view.default: 1})
+            .to_edge_transform_and_lower()
+            .check_node_count(
+                {
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+        for test_input in test_inputs:
+            tester.run_method_and_compare_outputs(inputs=test_input)
+
+    def test_fp32_view_copy_decrease_rank(self):
+        inputs = (torch.randn(4, 4),)
+        (
+            Tester(self.View((-1)), inputs)
+            .export()
+            .check_node_count({torch.ops.aten.view.default: 1})
+            .to_edge_transform_and_lower()
+            .check_node_count(
+                {
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_fp32_view_copy_decrease_rank_dynamic(self):
+        test_inputs = (
+            (torch.randn(2, 2, 4),),
+            (torch.randn(2, 10, 4),),
+        )
+        dynamic_shapes = {
+            "x": {1: Dim("x", min=1, max=10) * 2},
+        }
+        inputs = (torch.randn(2, 4, 4),)
+        tester = (
+            Tester(self.View((-1)), inputs)
+            .export(Export(dynamic_shapes=dynamic_shapes))
+            .check_node_count({torch.ops.aten.view.default: 1})
+            .to_edge_transform_and_lower()
+            .check_node_count(
+                {
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+        for test_input in test_inputs:
+            tester.run_method_and_compare_outputs(inputs=test_input)
+
+    def test_fp32_view_copy_nhwc(self):
+        class ViewNHWC(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(3, 3, 3)
+                self.conv2 = torch.nn.Conv2d(3, 3, 3)
+
+            def forward(self, x):
+                y = self.conv1(x)
+                y = y.view(1, 3, 3, -1)
+                y = self.conv2(y)
+                return y.view(1, 3, 2, -1)
+
+        inputs = (torch.randn(1, 3, 8, 8),)
+        (
+            Tester(ViewNHWC(), inputs)
+            .export()
+            .dump_artifact()
+            .check_node_count({torch.ops.aten.view.default: 2})
+            .to_edge_transform_and_lower()
+            .dump_artifact()
+            .check_node_count(
+                {
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                }
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
diff --git a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py
index a73a0eb0ad1..d823af9735e 100644
--- a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py
+++ b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
 import unittest
 
 import torch
@@ -23,6 +25,7 @@
     is_quant,
     is_tagged_as_implicit_q_dq,
 )
+from executorch.exir.dialects._ops import ops as exir_ops
 
 
 class TestChannelsLastTaggedReshapePass(unittest.TestCase):
@@ -480,3 +483,153 @@ def test_q_dq_nodes_around_copy_are_tagged(self):
 
         # Compare outputs
         tester.run_method_and_compare_outputs()
+
+    def test_fp32_channels_last_tagged_reshape_pass_nhwc_view(self):
+        # Views are always run in NCHW for now.
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(3, 3, 3)
+                self.conv2 = torch.nn.Conv2d(3, 3, 3)
+
+            def forward(self, x):
+                y = self.conv1(x)
+                y = y.view((1, 3, 3, -1))
+                return self.conv2(y)
+
+        inputs = (torch.randn(1, 3, 8, 8),)
+        (
+            Tester(Model(), inputs)
+            .export()
+            .to_edge()
+            .check_node_count(
+                {
+                    exir_ops.edge.aten.convolution.default: 2,
+                    exir_ops.edge.aten.view_copy.default: 1,
+                }
+            )
+            .run_passes(self.PassStage)
+            .run_method_and_compare_outputs()
+            .check_node_count(
+                {
+                    exir_ops.edge.aten.convolution.default: 2,
+                    exir_ops.edge.aten.view_copy.default: 1,
+                    # 4 dim order conversions - a pair at the start and end and a pair
+                    # around the view.
+                    exir_ops.edge.aten._to_copy.default: 4,
+                }
+            )
+        )
+
+    def test_fp32_channels_last_tagged_reshape_pass_nchw_view_channel_modified(self):
+        # View cannot run in NHWC because channel and/or batch are modified.
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(3, 3, 3)
+                self.conv2 = torch.nn.Conv2d(6, 3, 3)
+
+            def forward(self, x):
+                y = self.conv1(x)
+                y = y.view((1, 6, 6, -1))
+                return self.conv2(y)
+
+        inputs = (torch.randn(1, 3, 8, 8),)
+        (
+            Tester(Model(), inputs)
+            .export()
+            .to_edge()
+            .check_node_count(
+                {
+                    exir_ops.edge.aten.convolution.default: 2,
+                    exir_ops.edge.aten.view_copy.default: 1,
+                }
+            )
+            .run_passes(self.PassStage)
+            .run_method_and_compare_outputs()
+            .check_node_count(
+                {
+                    exir_ops.edge.aten.convolution.default: 2,
+                    exir_ops.edge.aten.view_copy.default: 1,
+                    exir_ops.edge.aten._to_copy.default: 4,
+                }
+            )
+        )
+
+    def test_fp32_channels_last_tagged_reshape_pass_nchw_view_batch_modified(self):
+        # View cannot run in NHWC because channel and/or batch are modified.
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(3, 3, 3)
+                self.conv2 = torch.nn.Conv2d(3, 3, 3)
+
+            def forward(self, x):
+                y = self.conv1(x)
+                y = y.view((2, 3, 6, -1))
+                return self.conv2(y)
+
+        inputs = (torch.randn(1, 3, 8, 8),)
+        (
+            Tester(Model(), inputs)
+            .export()
+            .to_edge()
+            .check_node_count(
+                {
+                    exir_ops.edge.aten.convolution.default: 2,
+                    exir_ops.edge.aten.view_copy.default: 1,
+                }
+            )
+            .run_passes(self.PassStage)
+            .run_method_and_compare_outputs()
+            .check_node_count(
+                {
+                    exir_ops.edge.aten.convolution.default: 2,
+                    exir_ops.edge.aten.view_copy.default: 1,
+                    exir_ops.edge.aten._to_copy.default: 4,
+                }
+            )
+        )
+
+    def test_fp32_channels_last_tagged_reshape_pass_flatten_view(self):
+        # View cannot run in NHWC because tensor rank changes.
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(3, 3, 3)
+                self.linear1 = torch.nn.Linear(36 * 3, 1)
+
+            def forward(self, x):
+                y = self.conv1(x)
+                y = y.view((x.shape[0], -1))
+                return self.linear1(y)
+
+        inputs = (torch.randn(1, 3, 8, 8),)
+        tester = (
+            Tester(Model(), inputs)
+            .export()
+            .to_edge()
+            .check_node_count(
+                {
+                    exir_ops.edge.aten.convolution.default: 1,
+                    exir_ops.edge.aten.view_copy.default: 1,
+                }
+            )
+            .run_passes(self.PassStage)
+            .run_method_and_compare_outputs()
+            .check_node_count(
+                {
+                    exir_ops.edge.aten.convolution.default: 1,
+                    exir_ops.edge.aten.view_copy.default: 1,
+                    exir_ops.edge.aten._to_copy.default: 2,
+                }
+            )
+        )
+
+        # Verify view is not tagged.
+        graph = tester.get_artifact().exported_program().module().graph
+        view_nodes = [
+            n for n in graph.nodes if n.target == exir_ops.edge.aten.view_copy.default
+        ]
+        self.assertEqual(1, len(view_nodes))
+        self.assertTrue(ChannelsLastTaggedReshapePass(None).is_nchw_node(view_nodes[0]))
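
Usage sketch (not part of the diff above): a minimal example of the lowering path this patch enables, assuming the standard ExecuTorch export flow. The Reshape module and the input shapes below are illustrative only; with this change, the view_copy node is expected to be partitioned and lowered to an XNNPACK static_reshape rather than remaining in the portable graph.

    import torch
    from torch.export import export

    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
    from executorch.exir import to_edge_transform_and_lower


    class Reshape(torch.nn.Module):
        def forward(self, x):
            # A view with one inferred (-1) dimension, which maps to a single
            # dynamic dimension in XNN static_reshape.
            return x.view(-1, 8)


    # Illustrative input; any shape compatible with the view works.
    exported = export(Reshape(), (torch.randn(4, 4),))

    # Partition and lower to XNNPACK; view_copy should now be delegated.
    edge = to_edge_transform_and_lower(exported, partitioner=[XnnpackPartitioner()])
    et_program = edge.to_executorch()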