Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 80 additions & 2 deletions backends/arm/_passes/decompose_meandim_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
from executorch.exir.backend.utils import WhyNoPartitionReporter
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
Expand Down Expand Up @@ -50,6 +51,15 @@ def get_view(op):
raise RuntimeError(f"Can't get meandim decomposition for op {op}")


def get_quantization(op):
"""Returns quant and dequant op of same type (per_channel/ tensor) as op if op is a dequant node, None otherwise."""
if op in DQ_OPS:
# Input of op can be placeholder, can't use that to get quant node directly.
quant_type_index = DQ_OPS.index(op)
return Q_OPS[quant_type_index], op
return None


class DecomposeMeanDimPass(ArmPass):
"""
Decomposes a meandim into avg_pool and/or sum + mul (1/N) depending on which dims the mean is taken for:
Expand Down Expand Up @@ -121,6 +131,7 @@ def call_operator(self, op, args, kwargs, meta):
dims_to_reduce = [dim - 1 for dim in dims_to_reduce]

x = super().call_operator(view_op, (x, new_shape), {}, meta, True)
x = self._maybe_insert_q_dq_after(x, meta)

# Reduce (h,w) dims by avg pool if possible
x, dims_to_reduce = self._reduce_by_average_pool(op, x, dims_to_reduce, meta)
Expand All @@ -133,7 +144,7 @@ def call_operator(self, op, args, kwargs, meta):
dims_to_reduce = [dim + len(original_dims) - 1 for dim in dims_to_reduce]

x = super().call_operator(view_op, (x, temp_shape), {}, meta, True)

x = self._maybe_insert_q_dq_after(x, meta)
# Reduce remaining dims by sum
x = self._reduce_by_sum(op, x, dims_to_reduce, meta, dtype)

Expand All @@ -156,6 +167,45 @@ def _reduce_by_sum(self, op, input_node, dims, meta, dtype):
full = super().call_operator(
full_op, ([1] * len(output_shape), 1 / N), {"dtype": dtype}, meta, True
)
if (quant_ops := get_quantization(input_node.node.target)) is not None:
# Insert Q and DQ nodes after full op.
# Since the value of full is known, we can compute quant params such that dq(q_max_value)
q_op, dq_op = quant_ops
qmax = input_node.node.args[4]
full_quant_args = (
1 / (N * qmax), # Scale to map qmax to 1/N
0, # Zero point
*input_node.node.args[3:],
)
q_args = (full, *full_quant_args)
full = super().call_operator(
q_op,
q_args,
kwargs={},
meta=meta,
updated=True,
)
dq_args = (full, *full_quant_args)
full = super().call_operator(
dq_op, dq_args, kwargs={}, meta=meta, updated=True
)

# Insert Q and DQ nodes after sum op.
# Scale needs to be adjusted with N, since it was computed on data after the division with N.
sum_quant_args = (input_node.node.args[1] * N, *input_node.node.args[2:])
q_args = (sum, *sum_quant_args)
sum = super().call_operator(
q_op,
q_args,
kwargs={},
meta=meta,
updated=True,
)
dq_args = (sum, *sum_quant_args)
sum = super().call_operator(
dq_op, dq_args, kwargs={}, meta=meta, updated=True
)

return super().call_operator(mul_op, (sum, full), {}, meta, True)

def _reduce_by_average_pool(self, op, input_node, dims, meta):
Expand Down Expand Up @@ -190,10 +240,38 @@ def _reduce_by_average_pool(self, op, input_node, dims, meta):
)

if is_supported:
out = super().call_operator(avgpool_op, args, {}, meta, True)
out = self._maybe_insert_q_dq_after(out, meta)
return (
super().call_operator(avgpool_op, args, {}, meta, True),
out,
dims_to_reduce_by_sum,
)

else:
return input_node, dims

def _maybe_insert_q_dq_after(self, op, meta):
"""If the input node of op is a dequant node, insert a q-dq pair after op with identical quantization parameters."""

if len(op.node.all_input_nodes) > 1:
raise ValueError(
f"Expected one input to {op.node}, got inputs {op.node.all_input_nodes}"
)
input_node = op.node.all_input_nodes[0]
if (quant_ops := get_quantization(input_node.target)) is not None:
q_op, dq_op = quant_ops
quant_args = list(input_node.args[1:])
q_args = (op, *quant_args)
out = super().call_operator(
q_op,
q_args,
kwargs={},
meta=meta,
updated=True,
)
dq_args = (out, *quant_args)
return super().call_operator(
dq_op, dq_args, kwargs={}, meta=meta, updated=True
)
else:
return op
14 changes: 13 additions & 1 deletion backends/arm/test/ops/test_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
VgfPipeline,
)

aten_op = "torch.ops.aten.avg_pool2d.default"
aten_op = "avg_pool2d.default"
exir_op = "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"

input_t = Tuple[torch.Tensor]
Expand All @@ -34,6 +34,15 @@ def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)


class BecomesMeanInToEdge(torch.nn.Module):
"""This averagepool will be converted to mean when lowering to edge. This causes the decompose_meandim pass to not
trigger until the backend pipeline, which requires extra care.
"""

def forward(self, x: torch.Tensor):
return torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))


test_modules = {
"zeros": lambda: (AvgPool2d(4, 2, 0, False), (torch.zeros(1, 16, 50, 32),)),
"ones": lambda: (AvgPool2d(4, 2, 0, False, True), (torch.ones(1, 16, 50, 32),)),
Expand Down Expand Up @@ -110,6 +119,9 @@ def forward(self, *args, **kwargs):
AvgPool2d(3, (1, 3), 1, count_include_pad=False),
(torch.rand(1, 16, 54, 54),),
),
"becomes_mean_rank3": lambda: (BecomesMeanInToEdge(), (torch.rand(2, 8, 8),)),
"becomes_mean_rank4": lambda: (BecomesMeanInToEdge(), (torch.rand(2, 2, 8, 8),)),
"becomes_mean_rank5": lambda: (BecomesMeanInToEdge(), (torch.rand(2, 2, 8, 8),)),
}


Expand Down
Loading