
Commit 93709cb

Serialization from EXIR

- For Exynos AI LiteCore, the EXIR graph is serialized and converted.

Signed-off-by: chong-chen <[email protected]>
Signed-off-by: jiseong.oh <[email protected]>
1 parent e50bee4 commit 93709cb

File tree

4 files changed: +229 −0 lines changed

backends/samsung/enn_preprocess.py
backends/samsung/serialization/compile_options.py
backends/samsung/serialization/compile_options_def.fbs
backends/samsung/serialization/enn_graph_schema.py

backends/samsung/enn_preprocess.py

Lines changed: 5 additions & 0 deletions

@@ -9,6 +9,11 @@
 
 import executorch.backends.samsung.python.PyEnnWrapperAdaptor as PyEnnWrapper
 import torch
+from executorch.backends.samsung.serialization.compile_options import (
+    ENN_COMPILE_OPTION_TITLE,
+)
+from executorch.backends.samsung.serialization.enn_graph_schema import EnnGraph
+
 from executorch.exir.backend.backend_details import (
     BackendDetails,
     CompileSpec,
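
The hunk above only adds imports. As a hypothetical sketch (not code from this commit), this is how a backend preprocess step could consume them: ENN_COMPILE_OPTION_TITLE identifies the serialized options among the compile specs handed to the backend, and EnnGraph receives the converted EXIR graph. The helper name below is illustrative.

# Hypothetical illustration, not code from this commit.
from typing import List

from executorch.backends.samsung.serialization.compile_options import (
    ENN_COMPILE_OPTION_TITLE,
)
from executorch.exir.backend.backend_details import CompileSpec


def find_enn_option_spec(compile_specs: List[CompileSpec]) -> bytes:
    # Select the flatbuffer-serialized ENN options by their well-known key.
    for spec in compile_specs:
        if spec.key == ENN_COMPILE_OPTION_TITLE:
            return spec.value
    raise RuntimeError("no ENN compile options found in compile specs")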
backends/samsung/serialization/compile_options.py

Lines changed: 87 additions & 0 deletions

# Copyright (c) 2025 Samsung Electronics Co. LTD
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import tempfile

from dataclasses import dataclass
from enum import IntEnum, unique
from typing import Dict, List, Optional

import pkg_resources
from executorch.exir._serialize._dataclass import _DataclassEncoder
from executorch.exir._serialize._flatbuffer import _flatc_compile
from executorch.exir.backend.backend_details import CompileSpec


@unique
class SamsungChipset(IntEnum):
    UNDEFINED_CHIP_V = 0
    E9955 = 9955


@dataclass
class DebugOption:
    name: str  # option name as a key
    value: str


@dataclass
class EnnExecuTorchOptions:
    chipset: SamsungChipset = SamsungChipset.UNDEFINED_CHIP_V
    debug_options: Optional[List[DebugOption]] = None


ENN_COMPILE_OPTION_TITLE = "enn_compile_options"
COMPILE_OPTION_SCHEMA_NAME = "compile_options_def"


def gen_samsung_backend_compile_spec_core(options: EnnExecuTorchOptions) -> CompileSpec:
    with tempfile.TemporaryDirectory() as d:
        # Write the flatbuffer schema packaged with this module to disk.
        schema_path = os.path.join(d, "{}.fbs".format(COMPILE_OPTION_SCHEMA_NAME))
        with open(schema_path, "wb") as schema_file:
            schema_file.write(
                pkg_resources.resource_string(
                    __name__, "{}.fbs".format(COMPILE_OPTION_SCHEMA_NAME)
                )
            )
        # Dump the options dataclass as JSON next to the schema.
        json_path = os.path.join(d, "{}.json".format(COMPILE_OPTION_SCHEMA_NAME))
        enn_options_json = json.dumps(options, cls=_DataclassEncoder, indent=4)
        with open(json_path, "wb") as json_file:
            json_file.write(enn_options_json.encode("ascii"))

        # flatc compiles schema + JSON into the binary .eeto payload.
        _flatc_compile(d, schema_path, json_path)
        output_path = os.path.join(d, "{}.eeto".format(COMPILE_OPTION_SCHEMA_NAME))
        with open(output_path, "rb") as output_file:
            return CompileSpec(ENN_COMPILE_OPTION_TITLE, output_file.read())


def gen_samsung_backend_compile_spec(
    chipset: str, debug_options: Optional[Dict[str, str]] = None
):
    """
    Generate a CompileSpec carrying serialized compile options for the Samsung backend.

    Args:
        chipset (str): chipset name in SamsungChipset. For example, E9955 and e9955 both work.
        debug_options (Dict[str, str], optional): debug options passed through to compilation.

    Returns:
        CompileSpec: key is ENN_COMPILE_OPTION_TITLE, value is the flatbuffer-serialized options binary.
    """
    pass_debug_options = []
    if debug_options is not None:
        for key, value in debug_options.items():
            pass_debug_options.append(DebugOption(key, value))

    option = EnnExecuTorchOptions(
        getattr(SamsungChipset, chipset.upper()),
        pass_debug_options,
    )

    return gen_samsung_backend_compile_spec_core(option)
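
A minimal usage sketch of the helper above, assuming the module is importable and the flatc toolchain wrapped by _flatc_compile is available; the debug option key is illustrative, not one documented by this commit:

from executorch.backends.samsung.serialization.compile_options import (
    gen_samsung_backend_compile_spec,
)

# Chipset lookup is case-insensitive ("e9955" resolves to SamsungChipset.E9955).
# "dump_ir" is an illustrative debug key, not one documented by this commit.
spec = gen_samsung_backend_compile_spec("e9955", debug_options={"dump_ir": "true"})
assert spec.key == "enn_compile_options"  # ENN_COMPILE_OPTION_TITLE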
backends/samsung/serialization/compile_options_def.fbs

Lines changed: 30 additions & 0 deletions

//============================================================================
//
// Copyright (c) 2025 Samsung Electronics. All Rights Reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
//============================================================================

namespace enn_option;

// Identifier of a valid executor schema.
file_identifier "EETO";
// Extension of written files.
file_extension "eeto";

table DebugOption {
  name: string;
  value: string;
}

table EnnExecuTorchOptions {
  // The chipset version; specifies the SoC on which to compile and execute the model.
  chipset: int;

  // Debug options controlling compilation behavior.
  debug_options: [DebugOption];
}

root_type EnnExecuTorchOptions;
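
For reference, gen_samsung_backend_compile_spec_core compiles JSON against this schema with flatc. The sketch below assumes _DataclassEncoder serializes the options dataclass field-for-field; field names follow the schema, values are illustrative:

import json

# Illustrative options payload matching the EnnExecuTorchOptions table above;
# the enum serializes as its integer value (SamsungChipset.E9955 == 9955).
options_json = {
    "chipset": 9955,
    "debug_options": [{"name": "dump_ir", "value": "true"}],
}
print(json.dumps(options_json, indent=4))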
backends/samsung/serialization/enn_graph_schema.py

Lines changed: 107 additions & 0 deletions

# Copyright (c) 2025 Samsung Electronics Co. LTD
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any, Dict, List, Optional, Union

import executorch.backends.samsung.python.PyGraphWrapperAdaptor as PyGraphWrapper

import numpy as np

import torch


class EnnGraph:
    def __init__(self):
        # Default schema version.
        self.version = "0.6.0"

    def init(self, name: str, soc_name):
        self.name = name
        self.soc_name = soc_name
        self.graph = PyGraphWrapper.PyEnnGraphWrapper()
        self.graph.Init()

        self.inputs = []
        self.outputs = []

    def define_op(
        self,
        name,
        type,
        input_ids: List[int],
        output_ids: List[int],
        params: Optional[Dict] = None,
    ):
        op = PyGraphWrapper.PyEnnOpWrapper(name, type, input_ids, output_ids)

        if params is not None:
            assert isinstance(params, dict), "Please pass op params as a dict."
            for key in params:
                py_param_wrapper = PyGraphWrapper.OpParamWrapper(key)
                if isinstance(params[key], (list, tuple)):
                    py_param_wrapper.SetVectorValue(params[key])
                elif isinstance(params[key], str):
                    py_param_wrapper.SetStringValue(params[key])
                elif isinstance(params[key], (int, float, bool)):
                    py_param_wrapper.SetScalarValue(params[key])
                else:
                    logging.error("Unsupported param type.")
                # Attach the wrapped parameter to the op.
                op.AddOpParam(py_param_wrapper)

        self.graph.DefineOpNode(op)

    def define_tensor(  # noqa: C901
        self,
        name: str,
        shape: List,
        data_type: str,
        tensor_type: str,
        data: Optional[Union[np.ndarray, torch.Tensor]] = None,
        quant_param: Optional[Dict[str, Any]] = None,
    ) -> int:
        layout = "NCHW" if len(shape) == 4 else "UNDEFINED"

        tensor = PyGraphWrapper.PyEnnTensorWrapper(name, shape, data_type, layout)

        if data is not None:
            if isinstance(data, torch.Tensor):
                data = data.detach().numpy()
            tensor.AddData(data)

        tensor_id = self.graph.DefineTensor(tensor)

        if tensor_type == "INPUT":
            self.inputs.append(tensor_id)
        elif tensor_type == "OUTPUT":
            self.outputs.append(tensor_id)

        return tensor_id

    def finish(self):
        self.graph.SetGraphInputTensors(self.inputs)
        self.graph.SetGraphOutputTensors(self.outputs)
        self.graph.FinishBuild()

    def serialize(self):
        return self.graph.Serialize()

    @staticmethod
    def _affine_meta_param(param: Any) -> Any:
        type_str_affine_table = {
            torch.int32: "FLOAT32",  # INT32 is only used for HW quant.
        }
        if isinstance(param, str):
            return param
        if isinstance(param, (float, int)):
            return [param]
        if hasattr(param, "tolist"):
            return param.tolist()
        if isinstance(param, torch.dtype):
            # Convenient for debugging.
            param = type_str_affine_table.get(param, "")

        return param
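
A minimal sketch of building and serializing a one-op graph with EnnGraph. The op and tensor type strings ("CONV2D", "FLOAT32", "INPUT", "OUTPUT", "PARAM") are illustrative assumptions; the accepted vocabulary is defined by the native PyGraphWrapperAdaptor module:

import numpy as np

# Illustrative only: real op/tensor type names come from PyGraphWrapperAdaptor.
graph = EnnGraph()
graph.init("tiny_model", "E9955")

in_id = graph.define_tensor("input", [1, 3, 8, 8], "FLOAT32", "INPUT")
w_id = graph.define_tensor(
    "weight",
    [4, 3, 3, 3],
    "FLOAT32",
    "PARAM",  # neither INPUT nor OUTPUT, so not added to the graph I/O lists
    data=np.zeros((4, 3, 3, 3), dtype=np.float32),
)
out_id = graph.define_tensor("output", [1, 4, 6, 6], "FLOAT32", "OUTPUT")

graph.define_op("conv0", "CONV2D", [in_id, w_id], [out_id], params={"stride": [1, 1]})
graph.finish()
payload = graph.serialize()  # bytes handed to the ENN compiler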
