Skip to content

Commit e50bee4

Browse files
committed
Add wrappers for samsung backend
Add the op wrapper and tensor wrapper for the LiteCore converter, along with the op param wrapper and quantize param wrapper. Signed-off-by: chong-chen <[email protected]> Signed-off-by: jiseong.oh <[email protected]>
1 parent df30f69 commit e50bee4

File tree

8 files changed

+606
-2
lines changed

8 files changed

+606
-2
lines changed

backends/samsung/aot/CMakeLists.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
# Sources for the ENN compile/execute Python binding adaptor.
target_sources(
  PyEnnWrapperAdaptor PUBLIC PyEnnWrapperAdaptor.cpp
  PyEnnWrapperAdaptor.h
)

# Sources for the ENN graph-builder Python binding adaptor, including the
# op / tensor / op-param wrapper headers it exposes.
target_sources(
  PyGraphWrapperAdaptor PUBLIC PyGraphWrapperAdaptor.cpp
  PyGraphWrapperAdaptor.h wrappers/op_param_wrapper.h
  wrappers/op_wrapper.h wrappers/tensor_wrapper.h
)
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
 * Copyright (c) 2025 Samsung Electronics Co. LTD
 * All rights reserved
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 */

#include "PyGraphWrapperAdaptor.h"

namespace torch {
namespace executor {
namespace enn {

// Python bindings for the ENN graph-building wrappers. `py` is the
// pybind11 namespace alias introduced by PyGraphWrapperAdaptor.h.
PYBIND11_MODULE(PyGraphWrapperAdaptor, m) {
  // Op parameter wrapper. One SetScalarValue / SetVectorValue overload is
  // registered per supported element type; the registration order is kept
  // as-is because pybind11 tries overloads in definition order.
  auto op_param_cls =
      py::class_<OpParamWrapper, std::shared_ptr<OpParamWrapper>>(
          m, "OpParamWrapper");
  op_param_cls.def(py::init<std::string>());
  op_param_cls.def("SetStringValue", &OpParamWrapper::SetStringValue);
  op_param_cls.def("SetScalarValue", &OpParamWrapper::SetScalarValue<double>);
  op_param_cls.def("SetScalarValue", &OpParamWrapper::SetScalarValue<float>);
  op_param_cls.def("SetScalarValue", &OpParamWrapper::SetScalarValue<bool>);
  op_param_cls.def("SetScalarValue", &OpParamWrapper::SetScalarValue<uint32_t>);
  op_param_cls.def("SetScalarValue", &OpParamWrapper::SetScalarValue<int32_t>);
  op_param_cls.def("SetScalarValue", &OpParamWrapper::SetScalarValue<uint64_t>);
  op_param_cls.def("SetScalarValue", &OpParamWrapper::SetScalarValue<int64_t>);
  op_param_cls.def("SetVectorValue", &OpParamWrapper::SetVectorValue<double>);
  op_param_cls.def("SetVectorValue", &OpParamWrapper::SetVectorValue<float>);
  op_param_cls.def("SetVectorValue", &OpParamWrapper::SetVectorValue<uint32_t>);
  op_param_cls.def("SetVectorValue", &OpParamWrapper::SetVectorValue<int32_t>);
  op_param_cls.def("SetVectorValue", &OpParamWrapper::SetVectorValue<uint64_t>);
  op_param_cls.def("SetVectorValue", &OpParamWrapper::SetVectorValue<int64_t>);

  // Tensor wrapper: constructed from name, shape, data type and layout;
  // quantize parameters and constant data are attached afterwards.
  auto tensor_cls =
      py::class_<EnnTensorWrapper, std::shared_ptr<EnnTensorWrapper>>(
          m, "PyEnnTensorWrapper");
  tensor_cls.def(py::init<
                 std::string,
                 const std::vector<DIM_T>&,
                 std::string,
                 std::string>());
  tensor_cls.def(
      "AddQuantizeParam",
      &EnnTensorWrapper::AddQuantizeParam,
      "Add quantize parameter.");
  tensor_cls.def(
      "AddData", &EnnTensorWrapper::AddData, "Add data for constant tensor.");

  // Op wrapper: name, op type, and input/output tensor ids.
  auto op_cls = py::class_<EnnOpWrapper, std::shared_ptr<EnnOpWrapper>>(
      m, "PyEnnOpWrapper");
  op_cls.def(py::init<
             std::string,
             std::string,
             const std::vector<TENSOR_ID_T>&,
             const std::vector<TENSOR_ID_T>&>());
  op_cls.def(
      "AddOpParam", &EnnOpWrapper::AddOpParam, "Add parameter for current op.");

  // Graph wrapper: the builder that assembles tensors and op nodes into a
  // serialized ENN graph.
  auto graph_cls =
      py::class_<PyEnnGraphWrapper, std::shared_ptr<PyEnnGraphWrapper>>(
          m, "PyEnnGraphWrapper");
  graph_cls.def(py::init());
  graph_cls.def("Init", &PyEnnGraphWrapper::Init, "Initialize Graph Wrapper.");
  graph_cls.def(
      "DefineTensor",
      &PyEnnGraphWrapper::DefineTensor,
      "Define a tensor in graph.");
  graph_cls.def(
      "DefineOpNode",
      &PyEnnGraphWrapper::DefineOpNode,
      "Define a op node in graph.");
  graph_cls.def(
      "SetGraphInputTensors",
      &PyEnnGraphWrapper::SetGraphInputTensors,
      "Set inputs for Graph");
  graph_cls.def(
      "SetGraphOutputTensors",
      &PyEnnGraphWrapper::SetGraphOutputTensors,
      "Set outputs for Graph");
  graph_cls.def(
      "FinishBuild",
      &PyEnnGraphWrapper::FinishBuild,
      "Finish to build the graph.");
  graph_cls.def(
      "Serialize", &PyEnnGraphWrapper::Serialize, "Serialize the graph.");
}

} // namespace enn
} // namespace executor
} // namespace torch
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
/*
 * Copyright (c) 2025 Samsung Electronics Co. LTD
 * All rights reserved
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 */
#pragma once

#include <include/common-types.h>
#include <include/graph_wrapper_api.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <cstring>
#include <exception>
#include <iostream>

#include "wrappers/op_param_wrapper.h"
#include "wrappers/op_wrapper.h"
#include "wrappers/tensor_wrapper.h"

namespace py = pybind11;

namespace torch {
namespace executor {
namespace enn {

/// Python-facing builder that assembles tensors and op nodes into an ENN
/// graph through the C graph_wrapper_api, and serializes the result.
/// Usage (from Python): Init() -> DefineTensor/DefineOpNode ->
/// SetGraphInput/OutputTensors -> FinishBuild -> Serialize.
class PyEnnGraphWrapper {
 public:
  PyEnnGraphWrapper() {}

  // The wrapper owns a native graph handle; copying it would lead to a
  // double release_graph() in the destructor.
  PyEnnGraphWrapper(const PyEnnGraphWrapper&) = delete;
  PyEnnGraphWrapper& operator=(const PyEnnGraphWrapper&) = delete;

  /// Creates the underlying native graph. Must be called before any other
  /// method.
  void Init() {
    graph_wrapper_ = create_graph("");
    initialized_ = true;
  }

  /// Registers `tensor` with the native graph and returns its tensor id.
  /// Also uploads constant data and quantization parameters when present.
  /// @throws std::runtime_error on any graph_wrapper_api failure.
  TENSOR_ID_T DefineTensor(std::shared_ptr<EnnTensorWrapper> tensor) const {
    TENSOR_ID_T tensor_id;
    auto result = define_tensor(
        graph_wrapper_,
        &tensor_id,
        tensor->GetName().c_str(),
        tensor->GetShape().data(),
        tensor->GetShape().size(),
        tensor->GetDataType().c_str(),
        tensor->GetLayout().c_str());
    if (result != GraphWrapperReturn::SUCCESS) {
      throw std::runtime_error("fail in define tensor");
    }

    if (tensor->HasConstantData()) {
      auto set_data_result = set_data_for_constant_tensor(
          graph_wrapper_,
          tensor_id,
          tensor->GetDataRawPtr(),
          tensor->GetDataBytes());
      if (set_data_result != GraphWrapperReturn::SUCCESS) {
        // Distinct message: this step fails after the tensor was defined.
        throw std::runtime_error("fail to set constant data for tensor");
      }
    }

    auto* quantize_param = tensor->GetQuantizeParam();
    if (quantize_param != nullptr) {
      auto set_qparam_result = set_quantize_param_for_tensor(
          graph_wrapper_,
          tensor_id,
          quantize_param->GetQuantizeType().c_str(),
          quantize_param->GetScales(),
          quantize_param->GetZeroPoints());
      if (set_qparam_result != GraphWrapperReturn::SUCCESS) {
        throw std::runtime_error("fail to set quantize param for tensor");
      }
    }

    return tensor_id;
  }

  /// Registers `op` (and all of its parameters) with the native graph and
  /// returns its node id.
  /// @throws std::runtime_error on any graph_wrapper_api failure.
  NODE_ID_T DefineOpNode(std::shared_ptr<EnnOpWrapper> op) const {
    NODE_ID_T op_id;

    auto result = define_op_node(
        graph_wrapper_,
        &op_id,
        op->GetName().c_str(),
        op->GetType().c_str(),
        op->GetInputs().data(),
        op->GetInputs().size(),
        op->GetOutputs().data(),
        op->GetOutputs().size());
    if (result != GraphWrapperReturn::SUCCESS) {
      throw std::runtime_error("fail in define op");
    }

    for (const auto& param : op->GetParams()) {
      // Dump() returns a non-owning view into the param wrapper, which the
      // shared_ptr keeps alive for the duration of this call.
      auto param_result = add_op_parameter(
          graph_wrapper_, op_id, param->getKeyName().c_str(), param->Dump());
      // Check the status like every other graph_wrapper_api call; the
      // original silently ignored failures here.
      if (param_result != GraphWrapperReturn::SUCCESS) {
        throw std::runtime_error("fail to add op parameter");
      }
    }

    return op_id;
  }

  /// Marks the given tensor ids as the graph's inputs.
  void SetGraphInputTensors(const std::vector<TENSOR_ID_T>& tensors) const {
    auto result =
        set_graph_input_tensors(graph_wrapper_, tensors.data(), tensors.size());
    if (result != GraphWrapperReturn::SUCCESS) {
      throw std::runtime_error("fail in set graph inputs");
    }
  }

  /// Marks the given tensor ids as the graph's outputs.
  void SetGraphOutputTensors(const std::vector<TENSOR_ID_T>& tensors) const {
    auto result = set_graph_output_tensors(
        graph_wrapper_, tensors.data(), tensors.size());
    if (result != GraphWrapperReturn::SUCCESS) {
      throw std::runtime_error("fail in set graph outputs");
    }
  }

  /// Finalizes the graph; no tensors/ops may be added afterwards.
  void FinishBuild() const {
    auto result = finish_build_graph(graph_wrapper_);

    if (result != GraphWrapperReturn::SUCCESS) {
      throw std::runtime_error("fail to build graph");
    }
  }

  /// Serializes the graph and returns the bytes as a numpy char array
  /// owned by Python.
  py::array_t<char> Serialize() {
    uint64_t nbytes = 0;
    uint8_t* addr = nullptr;
    auto result = serialize(graph_wrapper_, &addr, &nbytes);

    if (result != GraphWrapperReturn::SUCCESS || addr == nullptr ||
        nbytes == 0) {
      throw std::runtime_error("fail to serialize");
    }

    // Copy into a numpy-owned buffer so Python controls the lifetime.
    // NOTE(review): ownership of `addr` is not documented by the API —
    // confirm whether the serialized buffer must be freed here.
    auto serial_buf = py::array_t<char>(nbytes);
    auto serial_buf_block = serial_buf.request();
    char* serial_buf_ptr = static_cast<char*>(serial_buf_block.ptr);
    std::memcpy(serial_buf_ptr, addr, nbytes);

    return serial_buf;
  }

  ~PyEnnGraphWrapper() {
    // Only release a handle that Init() actually created; the original
    // unconditionally passed an uninitialized handle to release_graph,
    // which is undefined behavior when Init() was never called.
    if (initialized_) {
      release_graph(graph_wrapper_);
    }
  }

 private:
  GraphHandler graph_wrapper_{}; // value-initialized until Init() runs
  bool initialized_ = false; // true once create_graph has been called
};

} // namespace enn
} // namespace executor
} // namespace torch
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
 * Copyright (c) 2025 Samsung Electronics Co. LTD
 * All rights reserved
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 */
#pragma once

#include <stdint.h>
#include <cstring>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include <include/common-types.h>

namespace torch {
namespace executor {
namespace enn {

// Maps a C++ element type to the backend ScalarType tag. The unspecialized
// template resolves to UNKNOWN so unsupported types can be rejected at
// compile time.
template <class T>
struct ScalarTypeCast {
  constexpr static ScalarType value = ScalarType::UNKNOWN;
};

template <>
struct ScalarTypeCast<uint64_t> {
  constexpr static ScalarType value = ScalarType::UINT64;
};

template <>
struct ScalarTypeCast<int64_t> {
  constexpr static ScalarType value = ScalarType::INT64;
};

template <>
struct ScalarTypeCast<uint32_t> {
  constexpr static ScalarType value = ScalarType::UINT32;
};

template <>
struct ScalarTypeCast<int32_t> {
  constexpr static ScalarType value = ScalarType::INT32;
};

template <>
struct ScalarTypeCast<float> {
  constexpr static ScalarType value = ScalarType::FLOAT32;
};

template <>
struct ScalarTypeCast<double> {
  constexpr static ScalarType value = ScalarType::FLOAT64;
};

template <>
struct ScalarTypeCast<bool> {
  constexpr static ScalarType value = ScalarType::BOOL;
};

/// Holds one named op parameter (scalar, vector, or string) in a private
/// byte buffer and exposes it as a non-owning ParamWrapper view via Dump().
class OpParamWrapper {
 public:
  explicit OpParamWrapper(std::string key) : key_name_(std::move(key)) {}

  ~OpParamWrapper() = default;

  std::string getKeyName() const {
    return key_name_;
  }

  /// Stores a single scalar of a supported element type.
  template <typename T>
  void SetScalarValue(T value) {
    static_assert(
        ScalarTypeCast<T>::value != ScalarType::UNKNOWN,
        "SetScalarValue: unsupported element type");
    // Fix: assign the member byte count instead of shadowing it with a
    // local `auto bytes_` (the member previously always stayed 0).
    bytes_ = sizeof(T);
    storage_ = std::make_unique<uint8_t[]>(bytes_);
    std::memcpy(storage_.get(), &value, bytes_);
    size_ = 1;
    is_scalar_ = true;
    scalar_type_ = ScalarTypeCast<T>::value;
  }

  /// Stores a vector of a supported element type; an empty vector yields a
  /// zero-length value.
  template <typename T>
  void SetVectorValue(const std::vector<T>& value) {
    static_assert(
        ScalarTypeCast<T>::value != ScalarType::UNKNOWN,
        "SetVectorValue: unsupported element type");
    bytes_ = sizeof(T) * value.size();
    storage_ = std::make_unique<uint8_t[]>(bytes_);
    std::memcpy(storage_.get(), value.data(), bytes_);
    size_ = value.size();
    is_scalar_ = false;
    scalar_type_ = ScalarTypeCast<T>::value;
  }

  /// Stores a string value as a CHAR vector (no NUL terminator is stored).
  void SetStringValue(const std::string& value) {
    bytes_ = sizeof(std::string::value_type) * value.size();
    storage_ = std::make_unique<uint8_t[]>(bytes_);
    std::memcpy(storage_.get(), value.data(), bytes_);
    size_ = value.size();
    is_scalar_ = false;
    scalar_type_ = ScalarType::CHAR;
  }

  /// Returns a non-owning view of the stored value. The view is valid only
  /// while this wrapper is alive and the value has not been replaced.
  ParamWrapper Dump() const {
    ParamWrapper param;
    param.data = storage_.get();
    param.size = size_; // element count, not bytes
    param.is_scalar = is_scalar_;
    param.type = scalar_type_;

    return param;
  }

 private:
  std::string key_name_; // parameter key as exposed to the graph API
  std::unique_ptr<uint8_t[]> storage_ = nullptr; // owned value bytes
  uint32_t size_ = 0; // number of stored elements
  uint32_t bytes_ = 0; // size of storage_ in bytes
  bool is_scalar_ = false;
  ScalarType scalar_type_ = ScalarType::UNKNOWN;
};

} // namespace enn
} // namespace executor
} // namespace torch

0 commit comments

Comments
 (0)