@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from enum import Enum
from typing import Optional, Tuple

@@ -161,6 +163,9 @@ def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
return node.target in self.memory_sensitive_ops_nhwc

def requires_nchw_inputs(self, node: torch.fx.Node) -> bool:
if node.target == exir_ops.edge.aten.view_copy.default:
return True

return node.target in self.memory_sensitive_ops_nchw

def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
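
With the change above, requires_nchw_inputs now reports view_copy.default as an op that needs NCHW inputs, so the surrounding pass hands reshape nodes data in the original NCHW order. As a brief editorial illustration (not part of this diff, only torch assumed), the snippet below shows why reshaping a buffer that has been physically permuted to NHWC does not reproduce PyTorch's view_copy semantics:

import torch

x = torch.arange(24).reshape(1, 2, 3, 4)           # logical NCHW tensor
want = x.reshape(1, 6, 4)                           # what aten.view_copy produces

# What an NHWC-tagged buffer physically holds after the backend's permutation:
nhwc_data = x.permute(0, 2, 3, 1).contiguous()

# Reshaping that permuted buffer directly yields a different element order,
# which is why a conversion back to NCHW (or a permuted target shape) is needed.
naive = nhwc_data.reshape(1, 6, 4)
print(torch.equal(want, naive))                                            # False
print(torch.equal(want, nhwc_data.permute(0, 3, 1, 2).reshape(1, 6, 4)))   # True
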
1 change: 1 addition & 0 deletions backends/xnnpack/operators/__init__.py
@@ -56,4 +56,5 @@
op_sub,
op_tanh,
op_to_copy,
op_view_copy,
)
12 changes: 2 additions & 10 deletions backends/xnnpack/operators/op_skip_ops.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Dict

import torch
@@ -59,16 +61,6 @@ class OpTCopyDefault(OpSkipOps):
target = "aten.t_copy.default"


@register_node_visitor
class OpViewCopyDefault(OpSkipOps):
"""
currently, do nothing if node is view_copy.default
need to handle this later on, currently view it as one of skip ops
"""

target = "aten.view_copy.default"


@register_node_visitor
class OpSymSizeInt(OpSkipOps):
"""
96 changes: 96 additions & 0 deletions backends/xnnpack/operators/op_view_copy.py
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
XNNGraph,
XNNStaticReshape,
XNode,
)
from executorch.backends.xnnpack.utils.utils import (
check_or_raise,
get_input_node,
PERM_NCHW_TO_NHWC,
)


@register_node_visitor
class ViewCopyVisitor(NodeVisitor):
target = "aten.view_copy.default"

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
xnn_graph: XNNGraph,
vals_to_ids: Dict[torch.fx.Node, int],
debug_handle: int,
) -> None:
self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

input_node = get_input_node(node, 0)

# input
input_id = vals_to_ids[input_node]

# output
output_id = vals_to_ids[node]

# input shape
check_or_raise(
"val" in input_node.meta,
"Missing val in tensor metadata for input when serializing XNNStaticReshape",
)

# output shape
check_or_raise(
"val" in node.meta,
"Missing val in tensor metadata for input when serializing XNNStaticReshape",
)

new_shape = node.args[1]
check_or_raise(
all(isinstance(d, int) for d in new_shape),
"Symbolic reshape parameter is not supported in XNNStaticReshape",
)

# PyTorch uses -1 for inferred dims, whereas XNNPACK expects 0.
new_shape = tuple(d if d != -1 else 0 for d in new_shape)

# Handle NHWC-tagged nodes - if this node's tensors have been permuted to NHWC,
# the NCHW-ordered target shape must be permuted correspondingly.
if "XNN_NHWC_NODE" in node.meta:
check_or_raise(len(new_shape) == 4, "Invalid NCHW shape")
new_shape = [new_shape[PERM_NCHW_TO_NHWC[n]] for n in range(4)]

num_dynamic_dims = sum(1 for d in new_shape if d == 0)

check_or_raise(
num_dynamic_dims <= 1,
"XNNPACK reshape only supports 1 dynamic dimension.",
)

ser_node = XNode(
xnode_union=XNNStaticReshape(
num_dims=len(new_shape),
new_shape=new_shape,
input_id=input_id,
output_id=output_id,
flags=0,
),
debug_handle=debug_handle,
)
xnn_graph.xnodes.append(ser_node)
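
As an editorial aside (not part of this diff), here is a minimal standalone sketch of the shape handling define_node performs before serializing XNNStaticReshape: PyTorch's inferred dimension marker -1 becomes XNNPACK's 0, and a 4-D target shape is permuted when the node carries the XNN_NHWC_NODE tag. The permutation list is assumed to match PERM_NCHW_TO_NHWC from the XNNPACK backend utils, and normalize_view_shape is an invented helper name used only for illustration.

PERM_NCHW_TO_NHWC = [0, 2, 3, 1]  # assumed to match backends/xnnpack/utils/utils.py

def normalize_view_shape(new_shape, is_nhwc_tagged):
    # PyTorch marks the inferred dimension with -1; XNNPACK's static reshape uses 0.
    shape = [d if d != -1 else 0 for d in new_shape]
    # An NHWC-tagged node holds physically permuted tensors, so the NCHW-ordered
    # target shape has to be permuted into NHWC order as well.
    if is_nhwc_tagged:
        assert len(shape) == 4, "NHWC handling only applies to 4-D shapes"
        shape = [shape[PERM_NCHW_TO_NHWC[n]] for n in range(4)]
    # XNNPACK's static reshape infers at most one dimension.
    assert sum(1 for d in shape if d == 0) <= 1
    return shape

print(normalize_view_shape((1, -1, 16), is_nhwc_tagged=False))  # [1, 0, 16]
print(normalize_view_shape((1, 8, 2, 8), is_nhwc_tagged=True))  # [1, 2, 8, 8]
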
2 changes: 2 additions & 0 deletions backends/xnnpack/partition/config/__init__.py
@@ -55,6 +55,7 @@
TanhConfig,
ToDimOrderCopyConfig,
UpsampleBilinear2dConfig,
ViewCopyConfig,
)
from executorch.backends.xnnpack.partition.config.node_configs import (
BatchNormConfig,
@@ -115,6 +116,7 @@
SquareRootConfig,
SubConfig,
UpsampleBilinear2dConfig,
ViewCopyConfig,
# Quant/Dequant Op Configs
QuantizedPerTensorConfig,
DeQuantizedPerTensorConfig,
37 changes: 37 additions & 0 deletions backends/xnnpack/partition/config/generic_node_configs.py
@@ -380,6 +380,43 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]


class ViewCopyConfig(GenericNodePartitionerConfig):
target_name = "view_copy.default"

def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]

def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
"""
XNNPACK's static_reshape requires a static (non-symbolic) target shape and supports at most one inferred dimension.
"""
if not self.check_common_constraints(node, ep):
return False

new_shape = node.args[1]

# Check for symbolic dims. They aren't lowerable to XNNPACK currently.
symbolic_dim_indices = [
i for i, d in enumerate(new_shape) if not isinstance(d, int)
]
if not all(isinstance(n, int) for n in new_shape):
why(
node,
reason=f"Symbolic reshape is not supported. Output shape is {new_shape} and dims at {symbolic_dim_indices} are symbolic.",
)
return False

dynamic_dim_indices = [i for i, d in enumerate(new_shape) if d == -1]
if len(dynamic_dim_indices) > 1:
why(
node,
reason=f"Only a single inferred dimension is supported. Output shape is {new_shape} and dims {dynamic_dim_indices} are inferred.",
)
return False

return True


class FloorConfig(GenericNodePartitionerConfig):
target_name = "floor.default"

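
Closing editorial sketch (not part of this diff): the helper below mirrors the two constraints ViewCopyConfig.check_constraints enforces, namely that every target dimension is a concrete int and that at most one dimension is inferred (-1). The function name and the string standing in for a symbolic size are illustrative only.

def view_copy_shape_is_lowerable(new_shape):
    # Symbolic sizes (e.g. torch.SymInt) are not ints and cannot be lowered.
    if not all(isinstance(d, int) for d in new_shape):
        return False
    # XNNPACK's static reshape infers at most one dimension.
    return sum(1 for d in new_shape if d == -1) <= 1

print(view_copy_shape_is_lowerable([1, -1, 16]))    # True
print(view_copy_shape_is_lowerable([-1, 16, -1]))   # False: two inferred dims
print(view_copy_shape_is_lowerable([1, "s0", 16]))  # False: non-int stands in for a symbolic dim
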