7 changes: 0 additions & 7 deletions backends/cadence/aot/compiler_utils.py
@@ -201,13 +201,6 @@ def contains_node_with_matching_target(
return any(node.target == op_target for node in nodes)


def is_quantized_tensor(x: torch.Tensor) -> bool:
"""
Return true if the tensor x is quantized
"""
return x.is_quantized
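
For reference, this helper was a thin wrapper around the is_quantized attribute on torch.Tensor; a minimal standalone sketch of the same check, with illustrative values only:

import torch

# A per-tensor quantized tensor reports is_quantized == True; a plain
# float tensor reports False.
x = torch.randn(2, 3)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.qint8)
assert not x.is_quantized
assert qx.is_quantized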


def get_scale(x: torch.Tensor) -> torch.Tensor:
"""
Return the scale of a quantized tensor as a float32 tensor.
189 changes: 3 additions & 186 deletions backends/cadence/aot/replace_ops.py
@@ -15,17 +15,15 @@
import math
import operator
from operator import neg
from typing import cast, Dict, Iterable, Optional, Sequence, Set, Tuple
from typing import cast, Dict, Iterable, Optional, Sequence, Tuple

import torch
import torch.fx
from executorch.backends.cadence.aot.compiler_utils import (
get_shape,
get_tensor_from_attr,
get_transposed_dims,
get_zero_point,
is_node_with_op,
is_quantized_tensor,
quantize_tensor_multiplier,
)
from executorch.backends.cadence.aot.fuse_ops import (
@@ -772,186 +770,6 @@ def call_operator(self, op, args, kwargs, meta):
return super().call_operator(target, new_args, kwargs, meta)


# TODO(matthiascremon): this is a fuse op, not a replace op
class ReplaceConvWithChannelLastConv:
"""
Convolution op in pytorch expects NCHW layout for input, weight, and output
tensors. However, if the input and output to the convolution op are originally
in NHWC layout, and are then permuted to conform to NCHW layout, we can fuse
the two permute ops with the convolution op, and call the NHWC layout
convolution op.
"""

def __init__(self):
self.counter = 0
self.graph_module = None

def __call__(self, graph_module: torch.fx.GraphModule):
self.replace_conv_with_nhwc_conv(graph_module)

def conv_layout_is_nhwc(self, node: torch.fx.Node) -> bool:
"""
Return true if the convolution input and output are connected to permute
ops, and the input/output to/from the permute ops is NHWC layout tensor.
"""
# There must only be a single user of the output node (which must be a
# permute/transpose op). The input of the convolution must be connected
# to a permute op, and that permute op should have a single user.
conv_inp = node.args[0]
assert isinstance(conv_inp, torch.fx.Node)
if len(node.users) != 1 or len(conv_inp.users) != 1:
return False

# Get the input and output (permute/transpose) nodes of the convolution
conv_user = list(node.users.keys())[0]
assert isinstance(conv_user, torch.fx.Node)
pt_nodes: Set[torch.fx.Node] = {conv_inp, conv_user}

# Any node in pt_nodes must not be a placeholder.
if contains_placeholder_or_param(pt_nodes):
return False

# Determine if the convolution is 1d or 2d. The output tensor must be
# 3- or 4-dimensional
out_shape = get_shape(self.graph_module, node)
assert out_shape is not None
out_dims = len(out_shape)
assert out_dims in {3, 4}, "Only supports conv1d and conv2d"
conv1d = out_dims == 3

# Get the possible targets for the nodes in pt_nodes. Since conv1d has
# 3-dimensional input and output tensors, the nodes in pt_nodes could
# be either permute or transpose op. For conv2d, the nodes in pt_nodes
# must be permute ops.
p_target = exir_ops.edge.aten.permute_copy.default
t_target = exir_ops.edge.aten.transpose_copy.int
pt_targets = [p_target] + ([t_target] if conv1d else [])

# If any node in pt_nodes is not a permute op (or a transpose op for conv1d),
# bail.
if any(x.target not in pt_targets for x in pt_nodes):
return False

# Now we need to determine the dimension permutations:
# If the input had NHWC layout, which was then permuted/transposed
# by a permute/transpose op to NCHW layout, the permutation must be
# [0, 3, 1, 2] (or [0, 2, 1] for conv1d).
# If the output had NCHW layout, and was then permuted to NHWC layout,
# the permutation must be [0, 2, 3, 1] (or [0, 2, 1] for conv1d).
nhwc_permute_order = {
node.args[0]: [0, 2, 1] if conv1d else [0, 3, 1, 2],
list(node.users.keys())[0]: [0, 2, 1] if conv1d else [0, 2, 3, 1],
}
for x in pt_nodes:
order = (
x.args[1]
if x.target == p_target
else get_transposed_dims(x, list(range(out_dims)))
)
if order != nhwc_permute_order[x]:
return False

return True
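
As a sanity check on the permutation orders described above, here is a minimal sketch in plain PyTorch; the shapes are illustrative only and not taken from this pass:

import torch

# An NHWC activation permuted with [0, 3, 1, 2] becomes NCHW, the order
# expected for the permute feeding the convolution input.
nhwc = torch.randn(1, 8, 8, 16)        # N, H, W, C
nchw = nhwc.permute([0, 3, 1, 2])      # N, C, H, W
assert nchw.shape == (1, 16, 8, 8)

# An NCHW convolution output permuted with [0, 2, 3, 1] goes back to NHWC,
# the order expected for the permute consuming the convolution output.
back = nchw.permute([0, 2, 3, 1])
assert back.shape == (1, 8, 8, 16)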

def replace_conv_with_nhwc_conv(self, graph_module: torch.fx.GraphModule):
self.graph_module = graph_module
graph = graph_module.graph
for node in graph.nodes:
# We are only interested in convolution nodes that have NHWC layout
if node.target not in {
exir_ops.edge.cadence.quantized_conv_nchw.default,
exir_ops.edge.cadence.convolution.default,
exir_ops.edge.cadence.quantized_transposed_conv.default,
exir_ops.edge.cadence.transposed_convolution.default,
} or not self.conv_layout_is_nhwc(node):
continue

# Get the args of convolution op
args = list(node.args)
# The input is connected to a permute/transpose op that converts the
# NHWC layout to NCHW layout. The input of the permute op will become
# this convolution op's input.
in_tp = args[0]
args[0] = in_tp.args[0]
# The weight is in NCHW layout. Permute it to NHWC layout.
weight_tensor = get_tensor_from_attr(graph_module, args[1])
assert isinstance(weight_tensor, torch.Tensor)
# We cannot directly permute a per-channel quantized tensor. We will
# dequantize it, permute the fp32 tensor, and then requantize the
# permuted tensor.
if (
is_quantized_tensor(weight_tensor)
and weight_tensor.qscheme() == torch.per_channel_affine
):
# We have already asserted during quantizing conv op that the
# quantization axis is 0.
dequant_weight = weight_tensor.dequantize()
dequant_weight = (
dequant_weight.permute([0, 2, 1])
if dequant_weight.dim() == 3
else dequant_weight.permute([0, 2, 3, 1])
)
weight_tensor = torch.quantize_per_channel(
dequant_weight.contiguous(),
weight_tensor.q_per_channel_scales(),
weight_tensor.q_per_channel_zero_points(),
0,
weight_tensor.dtype,
)
else:
weight_tensor = (
weight_tensor.permute([0, 2, 1])
if weight_tensor.dim() == 3
else weight_tensor.permute([0, 2, 3, 1])
)
# Make the weight tensor contiguous, since we have permuted it.
weight_tensor = weight_tensor.contiguous()
# Add the permuted weight into the graph, and update the weight in
# args.
with graph.inserting_before(node):
weight_name = f"_weight_nhwc_{self.counter}"
graph_module.register_buffer(weight_name, weight_tensor)
weight = graph.get_attr(weight_name)
args[1] = weight

# Set the 'channel_last' arg, which is the last arg, to True.
args[-1] = True
# Now update the convolution node args to mark it as NHWC convolution
node.args = tuple(args)

# Replace all the uses of the permute op connected to the output op
# with this convolution.
out_tp = list(node.users.keys())[0]
out_tp.replace_all_uses_with(node)
node.meta = out_tp.meta

# Erase the permute ops connected to the input and output of the
# convolution op.
graph.erase_node(in_tp)
graph.erase_node(out_tp)
self.counter += 1

graph_module.recompile()
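
A standalone sketch of the dequantize-permute-requantize round trip used above for per-channel quantized weights; the shapes, scales, and zero points below are illustrative only:

import torch

# Per-channel quantized weight in NCHW (OIHW) order, quantized along axis 0.
w = torch.randn(4, 3, 5, 5)
scales = torch.full((4,), 0.05)
zero_points = torch.zeros(4, dtype=torch.int64)
qw = torch.quantize_per_channel(w, scales, zero_points, 0, torch.qint8)

# A per-channel quantized tensor cannot be permuted directly, so dequantize,
# permute to channel-last (OHWI), make contiguous, and requantize on axis 0.
dq = qw.dequantize().permute([0, 2, 3, 1]).contiguous()
qw_nhwc = torch.quantize_per_channel(
    dq, qw.q_per_channel_scales(), qw.q_per_channel_zero_points(), 0, qw.dtype
)
assert qw_nhwc.shape == (4, 5, 5, 3)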


# This pass needs to be reworked to be compatible with PT2. It is an optimization
# pass anyway, so move it to opt level 2.
# TODO: T213724613 update and improve this pass.
# @register_cadence_pass(CadencePassAttribute(opt_level=2))
class ReplaceConvWithChannelLastConvPass(ExportPass):
"""
Replace the ATen convolution op with a custom conv op that takes NCHW or NHWC
layout input tensors, depending on the presence of permute/transpose ops
connected to the input tensor.
"""

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
result = ReplaceAtenConvolutionWithCadenceConvolutionPass()(graph_module)
assert result is not None
ReplaceConvWithChannelLastConv()(result.graph_module)
return result


@register_cadence_pass(CadencePassAttribute(opt_level=2))
class ReplaceTrivialConvWithLinear(ExportPass):
"""
@@ -1131,7 +949,7 @@ def transpose_dims(


@register_cadence_pass(CadencePassAttribute(opt_level=3))
class ForceChannelLastForConvPass(ExportPassWithTransposeHelper):
class ReplaceConvWithChannelLastConvPass(ExportPassWithTransposeHelper):
def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue:
shape = proxy.to_tensor().shape
if len(shape) == 3:
@@ -2441,9 +2259,8 @@ class CadenceReplaceOpsInGraph:
ReplaceRepeatWithCatPass,
ReplacePadWithCatPass,
ReplaceConstantPadNdWithSlicePass,
ReplaceConvWithChannelLastConvPass,
ReplaceAtenConvolutionWithCadenceConvolutionPass,
ForceChannelLastForConvPass,
ReplaceConvWithChannelLastConvPass,
ReplaceTrivialConvWithLinear,
ReplaceConvWithIm2RowAndLinear,
ReplaceTransposedConvWithLinearPass,
16 changes: 8 additions & 8 deletions backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -17,14 +17,14 @@
)
from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
from executorch.backends.cadence.aot.replace_ops import (
ForceChannelLastForConvPass,
MakeSliceAndCatDimOutermostPass,
ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
ReplaceAddMMWithLinearPass,
ReplaceAtenApproxGeluWithApproxGeluPass,
ReplaceAtenConvolutionWithCadenceConvolutionPass,
ReplaceConstantPadNdWithSlicePass,
ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
ReplaceConvWithChannelLastConvPass,
ReplaceConvWithIm2RowAndLinear,
ReplaceEmptyTensorsWithFullPass,
ReplaceFunctionallyEquivalentOpTargets,
@@ -1454,7 +1454,7 @@ def test_replace_linear_like_conv(self) -> None:
)


class TestForceChannelLastForConvPass(unittest.TestCase):
class TestReplaceConvWithChannelLastConvPass(unittest.TestCase):
def create_conv1d_graphmodule(
self, channels_last: Optional[bool] = None
) -> torch.fx.GraphModule:
Expand Down Expand Up @@ -1489,7 +1489,7 @@ def test_conv1d_default_channel_last(self) -> None:
self.assertEqual(count_node(gm, exir_ops.edge.aten.transpose_copy.int), 0)

# Apply replacement pass.
p = ForceChannelLastForConvPass()
p = ReplaceConvWithChannelLastConvPass()
gm_after_replacement = p.call(gm).graph_module
# Check that no replacement was made.
self.assertEqual(
@@ -1514,7 +1514,7 @@ def test_conv1d_no_transpose_if_already_channel_last(self) -> None:
self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1)

# Apply replacement pass.
p = ForceChannelLastForConvPass()
p = ReplaceConvWithChannelLastConvPass()
gm_after_replacement = p.call(gm).graph_module
# Check that no replacement was made.
self.assertEqual(
@@ -1566,7 +1566,7 @@ def test_convolution_default_channel_last(self) -> None:
self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)

# Apply replacement pass.
p = ForceChannelLastForConvPass()
p = ReplaceConvWithChannelLastConvPass()
gm_after_replacement = p.call(gm).graph_module
# Check that no replacement was made.
self.assertEqual(
@@ -1591,7 +1591,7 @@ def test_no_transpose_if_already_channel_last(self) -> None:
self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1)

# Apply replacement pass.
p = ForceChannelLastForConvPass()
p = ReplaceConvWithChannelLastConvPass()
gm_after_replacement = p.call(gm).graph_module
# Check that no replacement was made.
self.assertEqual(
@@ -1692,7 +1692,7 @@ def test_quantized_convolution_default_channel_last(self) -> None:
self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)

# Apply replacement pass.
p = ForceChannelLastForConvPass()
p = ReplaceConvWithChannelLastConvPass()
gm_after_replacement = p.call(gm).graph_module
# Check that no replacement was made.
self.assertEqual(
@@ -1717,7 +1717,7 @@ def test_no_transpose_if_already_quantized_conv_channel_last(self) -> None:
)

# Apply replacement pass.
p = ForceChannelLastForConvPass()
p = ReplaceConvWithChannelLastConvPass()
gm_after_replacement = p.call(gm).graph_module
# Check that no replacement was made.
self.assertEqual(