25 changes: 19 additions & 6 deletions backends/nxp/backend/edge_helper.py
@@ -4,10 +4,29 @@
# LICENSE file in the root directory of this source tree.

import torch

from torch.fx import GraphModule, Node
from torch.nn import Parameter


def _is_dequantize(node_: Node) -> bool:
    return node_.op == "call_function" and node_.target.__name__ in [
        "dequantize_per_tensor.default",
        "quantized_decomposed.dequantize_per_tensor.default",
        "dequantize_per_channel.default",
        "quantized_decomposed.dequantize_per_channel.default",
    ]


def _is_quantize(node_: Node) -> bool:
    return node_.op == "call_function" and node_.target.__name__ in [
        "quantize_per_tensor.default",
        "quantized_decomposed.quantize_per_tensor.default",
        "quantize_per_channel.default",
        "quantized_decomposed.quantize_per_channel.default",
    ]
Collaborator:
Why not compare the targets directly? It would be more reliable.
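
A minimal sketch of the reviewer's suggestion: collect the operator overload objects in a set and test node_.target against it directly, instead of matching __name__ strings. This assumes the decomposed quantization ops are available under torch.ops.quantized_decomposed (true once the quantized_decomposed op library has been registered, e.g. by ExecuTorch's quantizer imports); the set name is illustrative, and the plain dequantize_per_tensor.default variants from the PR's string list are omitted here for brevity:

import torch
from torch.fx import Node

# Hypothetical module-level set; membership is checked against the op
# overload objects themselves rather than their __name__ strings.
_DEQUANTIZE_TARGETS = {
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
}


def _is_dequantize(node_: Node) -> bool:
    return node_.op == "call_function" and node_.target in _DEQUANTIZE_TARGETS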



def input_tensor(node: Node, input_index: int) -> torch.Tensor:
    if len(node.all_input_nodes) <= input_index:
        raise IndexError
@@ -62,12 +81,6 @@ def node_is_effectively_static_tensor(
    if node_is_static_tensor(node, parameters_mapping):
        return True

    def _is_dequantize(node_: Node) -> bool:
        return node_.target.__name__ in {
            "quantized_decomposed.dequantize_per_tensor.default",
            "quantized_decomposed.dequantize_per_channel.default",
        }

    return _is_dequantize(node) and node_is_static_tensor(
        node.args[0], parameters_mapping
    )

This file was deleted.

8 changes: 0 additions & 8 deletions backends/nxp/backend/ir/tflite_optimizer/optimizer.py
@@ -11,9 +11,6 @@

from executorch.backends.nxp.backend.ir import logger
from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig
from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.fuse_activation_functions import (
    FuseActivationFunctions,
)
from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.move_relu_before_concat import (
    MoveActivationBeforeConcatenation,
)
@@ -27,8 +24,6 @@


class Optimization(Enum):
    FUSE_ACTIVATION_FUNCTIONS = 1

    FUSE_TRANSPOSE_OPERATORS = 5
    REMOVE_IDENTITY_TRANSPOSE_OPERATORS = 6

@@ -64,9 +59,6 @@ def __init__(
        self._builder = builder

        self.optimization_map = {
            Optimization.FUSE_ACTIVATION_FUNCTIONS: FuseActivationFunctions(
                builder, conversion_config
            ),
            Optimization.FUSE_TRANSPOSE_OPERATORS: FuseTransposeOperators(
                builder, conversion_config
            ),