102 changes: 102 additions & 0 deletions backends/nxp/aten_passes/move_activation_before_concat.py
@@ -0,0 +1,102 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch

from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec

from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult


class MoveActivationBeforeConcat(PassBase):
"""Move some operators around in the following pattern.
This is a common pattern that emerges from the conversion of separable convolutions.
This optimization works together with joint quantization of compute nodes and activations. Without it,
it is not beneficial.

│ │ │ │
┌──────▼──────┐ ┌──────▼──────┐ ┌──────▼──────┐ ┌──────▼──────┐
│ aten.conv2d │ ... │ aten.conv2d │ │ aten.conv2d │ ... │ aten.conv2d │
└──────┬──────┘ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘
└───────┐ ┌──────┘ │ │
┌──▼─────▼─┐ replace with ┌─────▼─────┐ ┌─────▼─────┐
│ aten.cat │ ──────────────► │ aten.relu │ ... │ aten.relu │
└────┬─────┘ └─────┬─────┘ └─────┬─────┘
│ └───────┐ ┌───────┘
┌─────▼─────┐ ┌──▼─────▼─┐
│ aten.relu │ │ aten.cat │
└─────┬─────┘ └────┬─────┘
│ │
"""

def __init__(self, neutron_target_spec: NeutronTargetSpec):
self.neutron_target_spec = neutron_target_spec

    def call(self, module: GraphModule) -> PassResult:
def _is_concat(node_: Node) -> bool:
return (
node_.op == "call_function"
and node_.target == torch.ops.aten.cat.default
)

made_changes = False

for node in module.graph.nodes:
            if not _is_concat(node):
                continue  # Not a `cat` node.

            cat_node = node
            if len(cat_node.users) != 1:
                continue  # The `cat` must have exactly one user (the activation).
            activation = next(iter(cat_node.users))

            # Check that all `cat` input nodes are fusable conv or linear nodes whose only user is the `cat`.
if not all(
self.neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten(
input_node
)
and len(input_node.users) == 1
for input_node in cat_node.all_input_nodes
):
continue

            # Check that the activation following the `cat` is supported on Neutron as a fused activation.
            if not self.neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten(
                activation
            ):
                continue

            # Loop over all `cat` input nodes and insert a new activation after each one.
for input_node in cat_node.all_input_nodes:
with module.graph.inserting_after(input_node):
new_activation = module.graph.call_function(
activation.target,
args=(),
kwargs=activation.kwargs,
)

new_activation.meta["source_fn_stack"] = [
(
new_activation.name,
activation.meta["source_fn_stack"][-1][-1],
)
]
new_activation.meta["val"] = input_node.meta["val"]

# Replace the uses of the input node with the new activation node.
input_node.replace_all_uses_with(new_activation)
new_activation.args = (input_node, *activation.args[1:])

# Replace the uses of the activation node with the cat node.
activation.replace_all_uses_with(cat_node)

module.graph.erase_node(activation)

made_changes = True

return PassResult(module, made_changes)
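
For orientation, here is a minimal sketch (not part of the PR) of the graph shape this pass targets. The module name, channel counts, and input shape are made up for illustration; the real pass additionally consults the `NeutronTargetSpec` before rewriting.

```python
import torch
from torch import nn


class ParallelConvs(nn.Module):
    # Two parallel convolutions feeding one concat, followed by ReLU --
    # the `conv2d -> cat -> relu` shape the pass rewrites.
    def __init__(self):
        super().__init__()
        self.conv_a = nn.Conv2d(3, 8, kernel_size=3)
        self.conv_b = nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        y = torch.cat([self.conv_a(x), self.conv_b(x)], dim=1)
        return torch.relu(y)


# Export to an ATen-level graph. The printed graph contains
# aten.conv2d -> aten.cat -> aten.relu, which the pass turns into
# aten.conv2d -> aten.relu -> aten.cat (one relu per conv branch).
ep = torch.export.export(ParallelConvs(), (torch.randn(1, 3, 16, 16),))
print(ep.graph_module.graph)
```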
9 changes: 8 additions & 1 deletion backends/nxp/aten_passes/neutron_aten_pass_manager.py
@@ -16,6 +16,9 @@
from executorch.backends.nxp.aten_passes.fuse_linear_and_add_pass import (
FuseLinearAndAddPass,
)
from executorch.backends.nxp.aten_passes.move_activation_before_concat import (
MoveActivationBeforeConcat,
)
from executorch.backends.nxp.aten_passes.remove_nodes_with_known_outputs import (
RemoveNodesWithKnownOutputs,
)
@@ -25,6 +28,7 @@
from executorch.backends.nxp.aten_passes.split_gru_based_on_num_layers import (
SplitGRUBasedOnNumLayers,
)
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from executorch.exir.pass_manager import PassManager
from torch import nn
from torch.fx.passes.infra.pass_base import PassResult
@@ -34,14 +38,17 @@

class NeutronAtenPassManager(PassManager):

def __init__(self, passes: list[PassType] = None):
def __init__(
self, neutron_target_spec: NeutronTargetSpec, passes: list[PassType] = None
):
passes: list[PassType] = passes or [
FuseBatchNormWithConvPass(),
FuseBatchNormWithLinearPass(),
SplitGroupConvolution(),
SplitGRUBasedOnNumLayers(),
RemoveNodesWithKnownOutputs(),
FuseLinearAndAddPass(),
MoveActivationBeforeConcat(neutron_target_spec),
]

super().__init__(passes)
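
Call sites must now pass a `NeutronTargetSpec` as the first argument. A hedged sketch of the updated usage; building a real spec is target-specific and elided here, and `graph_module` is assumed to be a traced `torch.fx.GraphModule`:

```python
from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
    NeutronAtenPassManager,
)
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec

spec: NeutronTargetSpec = ...  # assumed: constructed for your Neutron target
pm = NeutronAtenPassManager(spec)  # default pass list, incl. MoveActivationBeforeConcat
transformed = pm(graph_module).graph_module  # PassManager call returns a PassResult
```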

backends/nxp/backend/ir/tflite_optimizer/optimizations/move_relu_before_concat.py
This file was deleted.

8 changes: 0 additions & 8 deletions backends/nxp/backend/ir/tflite_optimizer/optimizer.py
@@ -11,9 +11,6 @@

from executorch.backends.nxp.backend.ir import logger
from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig
from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.move_relu_before_concat import (
MoveActivationBeforeConcatenation,
)
from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.permute_fully_connected_weights_after_reshape import (
PermuteFullyConnectedWeightsAfterReshape,
)
@@ -29,8 +26,6 @@ class Optimization(Enum):

PERMUTE_FULLY_CONNECTED_WEIGHTS_AFTER_RESHAPE = 12

MOVE_ACTIVATION_BEFORE_CONCAT = 15


class Optimizer:
"""
@@ -68,9 +63,6 @@ def __init__(
Optimization.PERMUTE_FULLY_CONNECTED_WEIGHTS_AFTER_RESHAPE: PermuteFullyConnectedWeightsAfterReshape(
builder, conversion_config
),
Optimization.MOVE_ACTIVATION_BEFORE_CONCAT: MoveActivationBeforeConcatenation(
builder, conversion_config
),
}

def optimize(
10 changes: 9 additions & 1 deletion backends/nxp/quantizer/neutron_quantizer.py
@@ -12,6 +12,7 @@
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from executorch.backends.nxp.quantizer.patterns import (
AbsPattern,
ActivationsConcatClusterPattern,
AdaptiveAvgPoolPattern,
AddmmPattern,
AddTensorPattern,
@@ -225,13 +226,16 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec):
self.op_to_applied_quantizer = {
pt: False for q in self.quantizers for pt in q.pattern.partition_types()
}
self.cluster_quantizers = [
NeutronAtenQuantizer(ActivationsConcatClusterPattern(self), static_qconfig)
]

def transform_for_annotation(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
model.graph.eliminate_dead_code() # Remove dead code to simplify the graph for the passes.

model = NeutronAtenPassManager()(model).graph_module
model = NeutronAtenPassManager(self.neutron_target_spec)(model).graph_module

model.graph.eliminate_dead_code() # Remove dead code again, in case it was created by the passes.

@@ -240,6 +244,10 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
self._annotate_inputs(model)

        # Annotate node clusters in the model before the per-operator annotation below.
for cluster_quantizer in self.cluster_quantizers:
cluster_quantizer.annotate(model)

nodes = list(model.graph.nodes)
for node in nodes:
if (
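
The cluster quantizer runs before the per-operator annotation loop, so the whole `conv/linear -> activation -> cat` cluster is annotated as one unit first. A hedged sketch of where this fires in a standard PT2E flow; `model`, `example_inputs`, and `spec` are assumed to be defined by the caller:

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer

quantizer = NeutronQuantizer(spec)  # `spec`: a NeutronTargetSpec, as above
exported = torch.export.export(model, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)  # annotate() -- incl. clusters -- runs here
prepared(*example_inputs)                     # calibration pass
quantized = convert_pt2e(prepared)
```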