
Commit f6f95bf

Support converting EXIR to Exynos's own IR
- Both ops and graphs can be converted.

Signed-off-by: jiseong.oh <[email protected]>

1 parent 93709cb

File tree

5 files changed: +225 additions, -8 deletions
backends/samsung/builders/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
# Copyright (c) 2025 Samsung Electronics Co. LTD
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from . import node_visitor

__all__ = [
    "node_visitor",
]
backends/samsung/builders/node_visitor.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
# Copyright (c) 2025 Samsung Electronics Co. LTD
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Optional

import torch
from executorch.backends.samsung.builders.utils import (
    get_map_dtype,
    get_tensor,
    get_tensor_type,
)
from executorch.backends.samsung.serialization.enn_graph_schema import EnnGraph
from torch.export import ExportedProgram


class NodeVisitor:
    """
    Node visitor pattern for visiting nodes in an edge IR graph
    """

    def __init__(self, exported_program: ExportedProgram) -> None:
        self._exported_program = exported_program

    @property
    def exported_program(self) -> ExportedProgram:
        return self._exported_program

    def define_node(
        self,
        node: torch.fx.Node,
        enn_graph: EnnGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
    ):
        raise NotImplementedError("NodeVisitor must be extended!")

    def define_tensor(
        self,
        node: torch.fx.Node,
        enn_graph: EnnGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        swap_nc_for_weights: bool = False,
        output_idx: Optional[int] = None,
    ) -> int:
        if node in vals_to_ids and (output_idx is None or output_idx == 0):
            return vals_to_ids[node]

        # Get basic tensor information
        tensor = get_tensor(self.exported_program, node)

        if output_idx is not None:
            tensor = tensor[output_idx]

        tensor_type = get_tensor_type(self.exported_program, node)
        data_type = get_map_dtype(tensor.dtype)

        const_data = None

        # Scalars are represented as rank-1 tensors of size 1
        dims = [1] if len(tensor.size()) == 0 else list(tensor.size())

        enn_tensor_id = enn_graph.define_tensor(
            node.name,
            dims,
            data_type,
            tensor_type.name,
            const_data,
        )
        assert enn_tensor_id is not None
        vals_to_ids[node] = enn_tensor_id

        return enn_tensor_id


_node_visitor_dict = {}


def register_node_visitor(visitor):
    assert (
        isinstance(visitor, type)
        and issubclass(visitor, NodeVisitor)
        and hasattr(visitor, "target")
    ), f"Ill-formed NodeVisitor subclass, can't register: {visitor}"
    if isinstance(visitor.target, str):
        _node_visitor_dict[visitor.target] = visitor
    elif isinstance(visitor.target, (list, tuple)):
        for target in visitor.target:
            _node_visitor_dict[target] = visitor
    else:
        raise TypeError(
            "target of visitor should be str | Tuple[str] | List[str], "
            f"not {type(visitor.target)}"
        )


def get_node_visitors(*args) -> Dict[str, NodeVisitor]:
    """
    Create a new visitor instance per registered target at runtime
    and collect them in a dict keyed by target name.
    """
    node_visitors = {}
    for target, visitor in _node_visitor_dict.items():
        assert callable(visitor), (
            f"Expecting a callable class, but got {visitor} of type {type(visitor)}"
        )
        node_visitors[target] = visitor(*args)
    return node_visitors
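
For reference, a minimal sketch of how a concrete visitor might plug into this registry. The `AddVisitor` class, the `aten.add.Tensor` target name, and the `EnnGraph.define_op` call are illustrative assumptions, not part of this commit:

# Hypothetical usage sketch: a visitor for an elementwise add.
from executorch.backends.samsung.builders.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)


class AddVisitor(NodeVisitor):
    target = "aten.add.Tensor"  # assumed edge-op name

    def define_node(self, node, enn_graph, vals_to_ids):
        # Map the FX inputs/output to ENN tensor ids, then emit the op.
        input_ids = [
            self.define_tensor(arg, enn_graph, vals_to_ids) for arg in node.args
        ]
        output_id = self.define_tensor(node, enn_graph, vals_to_ids)
        enn_graph.define_op(node.name, "ADD", input_ids, output_id)  # assumed API


register_node_visitor(AddVisitor)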

backends/samsung/builders/utils.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
# Copyright (c) 2025 Samsung Electronics Co. LTD
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from enum import Enum

import torch
from executorch.backends.samsung.utils.utils import is_graph_input, is_graph_output
from executorch.backends.transforms.utils import get_param_tensor, is_param_node

from torch.export import ExportedProgram

DATA_TYPE_STR_MAPPING = {
    torch.int8: "INT8",
    torch.uint8: "UINT8",
    torch.int16: "INT16",
    torch.uint16: "UINT16",
    torch.int32: "INT32",
    torch.int64: "INT64",
    torch.float16: "FLOAT16",
    torch.float32: "FLOAT32",
}

TORCH_TYPE_QTYPE_MAPPING = {
    torch.int8: torch.qint8,
    torch.uint8: torch.quint8,
    torch.int32: torch.qint32,
}


class TensorType(Enum):
    INPUT = 0
    OUTPUT = 1
    CONSTANT = 2
    FEATUREMAP = 3


def get_tensor_type(
    exported_program: ExportedProgram, tensor: torch.fx.Node
) -> TensorType:
    if is_graph_input(exported_program, tensor):
        return TensorType.INPUT
    elif is_graph_output(tensor):
        return TensorType.OUTPUT
    elif is_param_node(exported_program, tensor):
        return TensorType.CONSTANT
    else:
        return TensorType.FEATUREMAP


def get_map_dtype(dtype):
    if dtype not in DATA_TYPE_STR_MAPPING:
        raise RuntimeError(f"Data type cannot be mapped: {dtype}")
    return DATA_TYPE_STR_MAPPING[dtype]


def get_tensor(exported_program: ExportedProgram, node: torch.fx.Node):
    # Parameters/buffers are materialized; other nodes carry a FakeTensor in meta.
    if not is_param_node(exported_program, node):
        return node.meta["val"]
    tensor = get_param_tensor(exported_program, node)
    return tensor.contiguous()


def affine_type_to_str(ttype: TensorType):
    return str(ttype).removeprefix("TensorType.")
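
For reference, a quick sketch of how `get_map_dtype` behaves for supported and unsupported dtypes (illustrative usage, not part of this commit):

# Hypothetical usage sketch for the dtype mapping helper.
import torch

from executorch.backends.samsung.builders.utils import get_map_dtype

print(get_map_dtype(torch.float32))  # -> "FLOAT32"
print(get_map_dtype(torch.int64))    # -> "INT64"
try:
    get_map_dtype(torch.float64)     # not in DATA_TYPE_STR_MAPPING
except RuntimeError as err:
    print(err)                       # Data type cannot be mapped: torch.float64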

backends/samsung/enn_preprocess.py

Lines changed: 12 additions & 8 deletions
@@ -9,10 +9,12 @@
 
 import executorch.backends.samsung.python.PyEnnWrapperAdaptor as PyEnnWrapper
 import torch
+from executorch.backends.samsung.builders.node_visitor import get_node_visitors
 from executorch.backends.samsung.serialization.compile_options import (
     ENN_COMPILE_OPTION_TITLE,
 )
 from executorch.backends.samsung.serialization.enn_graph_schema import EnnGraph
+from executorch.backends.samsung.utils.utils import get_compile_spec
 
 from executorch.exir.backend.backend_details import (
     BackendDetails,
@@ -38,23 +40,23 @@ def preprocess(
         )
         enn_wrapper.Init(option_spec.value)
 
-        enn_preprocess_passes = PassManager(passes=[])
-
-        # 2 make enn graph
         enn_preprocess_passes = PassManager(passes=[])
         pass_result = enn_preprocess_passes(edge_program.graph_module)
+        assert pass_result is not None
 
         enn_graph = EnnGraph()
         enn_graph.init("UnknownName", "")
-        # 3 node visitors
-        node_visitors = []
+        # node visitors
+        node_visitors = get_node_visitors(edge_program)
 
         vals_to_ids: Dict[torch.fx.Node, int] = {}
         for node in pass_result.graph_module.graph.nodes:
             if node.op == "call_function":
                 logging.warning(f"Visiting: {node}, {node.target.__name__}")
                 if node.target.__name__ in node_visitors:
-                    pass
+                    node_visitors[node.target.__name__].define_node(
+                        node, enn_graph, vals_to_ids
+                    )
                 else:
                     raise RuntimeError(
                         f"{node.target.__name__} is not supported in ENN Delegate"
@@ -68,11 +70,13 @@ def preprocess(
             else:
                 raise RuntimeError(f"{node.op} is not supported in ENN Delegate")
 
-        # 4 Compile Graph
+        # Compile Graph
         enn_wrapper.Destroy()
         enn_graph.finish()
         ser_buf = enn_graph.serialize()
         enn_context_binary = enn_wrapper.Compile(ser_buf)
         assert enn_context_binary is not None and len(enn_context_binary) > 0
         enn_wrapper.Destroy()
-        return PreprocessResult(processed_bytes=bytes(enn_context_binary), debug_handle_map={})
+        return PreprocessResult(
+            processed_bytes=bytes(enn_context_binary), debug_handle_map={}
+        )
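
For context, a small sketch of what the dispatch key `node.target.__name__` looks like on an exported graph. Exact op names depend on the dialect and torch version, so treat the printed values as assumptions:

# Hypothetical sketch: inspecting the names used to look up node visitors.
import torch
from torch.export import export


class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y


ep = export(Add(), (torch.randn(2), torch.randn(2)))
for node in ep.graph_module.graph.nodes:
    if node.op == "call_function":
        print(node.target.__name__)  # e.g. "add.Tensor" for core ATen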

backends/samsung/utils/utils.py

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
# Copyright (c) 2025 Samsung Electronics Co. LTD
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional

import torch

from executorch.backends.transforms.utils import is_param_node
from executorch.exir.backend.backend_details import CompileSpec

from torch.export.exported_program import ExportedProgram


def get_compile_spec(
    compile_specs: List[CompileSpec], spec_name: str, required=False
) -> Optional[CompileSpec]:
    for spec in compile_specs:
        if spec_name == spec.key:
            return spec
    assert not required, f"Require {spec_name} but it doesn't exist."
    return None


def is_graph_input(exported_program: ExportedProgram, node: torch.fx.Node) -> bool:
    return node.op == "placeholder" and not is_param_node(exported_program, node)


def is_graph_output(node: torch.fx.Node) -> bool:
    # Look through getitem users (multi-output ops) to find the graph output.
    for user in node.users.keys():
        if user.op == "output" or (
            user.target.__name__ == "getitem" and is_graph_output(user)
        ):
            return True
    return False
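
For reference, a short sketch of `get_compile_spec` lookups; the spec key and value below are made up for illustration:

# Hypothetical usage sketch for compile-spec lookup.
from executorch.backends.samsung.utils.utils import get_compile_spec
from executorch.exir.backend.backend_details import CompileSpec

specs = [CompileSpec("enn_compile_option", b"\x00\x01")]  # made-up key/value

spec = get_compile_spec(specs, "enn_compile_option", required=True)
print(spec.key, spec.value)

# A missing, non-required key returns None instead of asserting.
assert get_compile_spec(specs, "absent_key") is None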
