Skip to content

Commit 7be65ad

Browse files
authored
[circle-resizer] Add CircleModel (#15185)
This commit adds new CircleModel class responsible for loading and processing model in Circle model. The additional features are extraction information about input/output shapes and saving the current version to file/stream. ONE-DCO-1.0-Signed-off-by: Mateusz Bencer m.bencer@partner.samsung.com
1 parent 67e8fd7 commit 7be65ad

File tree

6 files changed

+409
-4
lines changed

6 files changed

+409
-4
lines changed
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
* Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#ifndef __CIRCLE_RESIZER_CIRCLE_MODEL_H__
18+
#define __CIRCLE_RESIZER_CIRCLE_MODEL_H__
19+
20+
#include "Shape.h"
21+
22+
#include <string>
23+
#include <memory>
24+
#include <vector>
25+
26+
namespace luci
27+
{
28+
class Module;
29+
}
30+
31+
namespace circle_resizer
32+
{
33+
34+
/**
35+
* The representation of Circle Model.
36+
*/
37+
class CircleModel
38+
{
39+
public:
40+
/**
41+
* @brief Initialize the model with buffer representation.
42+
*
43+
* Exceptions:
44+
* - std::runtime_error if interpretation of provided buffer as a circle model failed.
45+
*/
46+
explicit CircleModel(const std::vector<uint8_t> &buffer);
47+
48+
/**
49+
* @brief Initialize the model with buffer representation.
50+
*
51+
* Exceptions:
52+
* - std::runtime_error if reading a model from provided path failed.
53+
*/
54+
explicit CircleModel(const std::string &model_path);
55+
56+
/**
57+
* @brief Dtor of CircleModel. Note that explicit declaration is needed to satisfy forward
58+
* declaration + unique_ptr.
59+
*/
60+
~CircleModel();
61+
62+
/**
63+
* @brief Get the loaded model in luci::Module representation.
64+
*/
65+
luci::Module *module();
66+
67+
/**
68+
* @brief Get input shapes of the loaded model.
69+
*/
70+
std::vector<Shape> input_shapes() const;
71+
72+
/**
73+
* @brief Get output shapes of the loaded model.
74+
*
75+
*/
76+
std::vector<Shape> output_shapes() const;
77+
78+
/**
79+
* @brief Save the model to the output stream.
80+
*
81+
* Exceptions:
82+
* - std::runtime_error if saving the model the given stream failed.
83+
*/
84+
void save(std::ostream &stream);
85+
86+
/**
87+
* @brief Save the model to the location indicated by output_path.
88+
*
89+
* Exceptions:
90+
* - std::runtime_error if saving the model the given path failed.
91+
*/
92+
void save(const std::string &output_path);
93+
94+
private:
95+
std::unique_ptr<luci::Module> _module;
96+
};
97+
98+
} // namespace circle_resizer
99+
100+
#endif // __CIRCLE_RESIZER_CIRCLE_MODEL_H__
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
require("arser")
2+
require("common-artifacts")
3+
require("mio-circle08")
24
require("safemain")
35
require("vconone")
Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1-
list(APPEND CIRCLE_RESIZER_CORE_SOURCES Dim.cpp)
2-
list(APPEND CIRCLE_RESIZER_CORE_SOURCES Shape.cpp)
3-
list(APPEND CIRCLE_RESIZER_CORE_SOURCES ShapeParser.cpp)
1+
list(APPEND CIRCLE_RESIZER_SOURCES Dim.cpp)
2+
list(APPEND CIRCLE_RESIZER_SOURCES Shape.cpp)
3+
list(APPEND CIRCLE_RESIZER_SOURCES ShapeParser.cpp)
4+
list(APPEND CIRCLE_RESIZER_SOURCES CircleModel.cpp)
45

5-
add_library(circle_resizer_core STATIC "${CIRCLE_RESIZER_CORE_SOURCES}")
6+
add_library(circle_resizer_core SHARED "${CIRCLE_RESIZER_SOURCES}")
67

78
target_include_directories(circle_resizer_core PUBLIC ../include)
89

10+
target_link_libraries(circle_resizer_core PRIVATE luci_export)
11+
target_link_libraries(circle_resizer_core PRIVATE luci_import)
12+
target_link_libraries(circle_resizer_core PRIVATE luci_lang)
13+
target_link_libraries(circle_resizer_core PRIVATE mio_circle08)
14+
915
install(TARGETS circle_resizer_core DESTINATION lib)
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
/*
2+
* Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "CircleModel.h"
18+
19+
#include <mio/circle/schema_generated.h>
20+
21+
#include <luci/Importer.h>
22+
#include <luci/CircleExporter.h>
23+
#include <luci/CircleFileExpContract.h>
24+
25+
#include <fstream>
26+
27+
using namespace circle_resizer;
28+
29+
namespace
30+
{
31+
32+
std::vector<uint8_t> read_model(const std::string &model_path)
33+
{
34+
std::ifstream file_stream(model_path, std::ios::in | std::ios::binary | std::ifstream::ate);
35+
if (!file_stream.is_open())
36+
{
37+
throw std::runtime_error("Failed to open file: " + model_path);
38+
}
39+
40+
std::streamsize size = file_stream.tellg();
41+
file_stream.seekg(0, std::ios::beg);
42+
43+
std::vector<uint8_t> buffer(size);
44+
if (!file_stream.read(reinterpret_cast<char *>(buffer.data()), size))
45+
{
46+
throw std::runtime_error("Failed to read file: " + model_path);
47+
}
48+
49+
return buffer;
50+
}
51+
52+
std::unique_ptr<luci::Module> load_module(const std::vector<uint8_t> &model_buffer)
53+
{
54+
flatbuffers::Verifier verifier{model_buffer.data(), model_buffer.size()};
55+
if (!circle::VerifyModelBuffer(verifier))
56+
{
57+
throw std::runtime_error("Verification of the model failed");
58+
}
59+
60+
const luci::GraphBuilderSource *source_ptr = &luci::GraphBuilderRegistry::get();
61+
luci::Importer importer(source_ptr);
62+
return importer.importModule(model_buffer.data(), model_buffer.size());
63+
}
64+
65+
class BufferModelContract : public luci::CircleExporter::Contract
66+
{
67+
public:
68+
BufferModelContract(luci::Module *module)
69+
: _module(module), _buffer{std::make_unique<std::vector<uint8_t>>()}
70+
{
71+
assert(_module); // FIX_CALLER_UNLESS
72+
}
73+
74+
luci::Module *module() const override { return _module; }
75+
76+
bool store(const char *ptr, const size_t size) const override
77+
{
78+
_buffer->resize(size);
79+
std::copy(ptr, ptr + size, _buffer->begin());
80+
return true;
81+
}
82+
83+
std::vector<uint8_t> get_buffer() { return *_buffer; }
84+
85+
private:
86+
luci::Module *_module;
87+
std::unique_ptr<std::vector<uint8_t>> _buffer; // note that the store method has to be const
88+
};
89+
90+
template <typename NodeType>
91+
std::vector<Shape> extract_shapes(const std::vector<loco::Node *> &nodes)
92+
{
93+
std::vector<Shape> shapes;
94+
for (const auto &loco_node : nodes)
95+
{
96+
std::vector<Dim> dims;
97+
const auto circle_node = loco::must_cast<const NodeType *>(loco_node);
98+
for (uint32_t dim_idx = 0; dim_idx < circle_node->rank(); dim_idx++)
99+
{
100+
if (circle_node->dim(dim_idx).known())
101+
{
102+
const int32_t dim_val = circle_node->dim(dim_idx).value();
103+
dims.push_back(Dim{dim_val});
104+
}
105+
else
106+
{
107+
dims.push_back(Dim{-1});
108+
}
109+
}
110+
shapes.push_back(Shape{dims});
111+
}
112+
return shapes;
113+
}
114+
115+
} // namespace
116+
117+
CircleModel::CircleModel(const std::vector<uint8_t> &buffer) : _module{load_module(buffer)} {}
118+
119+
CircleModel::CircleModel(const std::string &model_path) : CircleModel(read_model(model_path)) {}
120+
121+
luci::Module *CircleModel::module() { return _module.get(); }
122+
123+
void CircleModel::save(std::ostream &stream)
124+
{
125+
BufferModelContract contract(module());
126+
luci::CircleExporter exporter;
127+
if (!exporter.invoke(&contract))
128+
{
129+
throw std::runtime_error("Exporting buffer from the model failed");
130+
}
131+
132+
auto model_buffer = contract.get_buffer();
133+
stream.write(reinterpret_cast<const char *>(model_buffer.data()), model_buffer.size());
134+
if (!stream.good())
135+
{
136+
throw std::runtime_error("Failed to write to output stream");
137+
}
138+
}
139+
140+
void CircleModel::save(const std::string &output_path)
141+
{
142+
std::ofstream out_stream(output_path, std::ios::out | std::ios::binary);
143+
save(out_stream);
144+
}
145+
146+
std::vector<Shape> CircleModel::input_shapes() const
147+
{
148+
return extract_shapes<luci::CircleInput>(loco::input_nodes(_module->graph()));
149+
}
150+
151+
std::vector<Shape> CircleModel::output_shapes() const
152+
{
153+
return extract_shapes<luci::CircleOutput>(loco::output_nodes(_module->graph()));
154+
}
155+
156+
CircleModel::~CircleModel() = default;

compiler/circle-resizer/tests/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@ endif(NOT ENABLE_TEST)
44

55
list(APPEND CIRCLE_RESIZER_TEST_SOURCES Shape.test.cpp)
66
list(APPEND CIRCLE_RESIZER_TEST_SOURCES ShapeParser.test.cpp)
7+
list(APPEND CIRCLE_RESIZER_TEST_SOURCES CircleModel.test.cpp)
78

89
nnas_find_package(GTest REQUIRED)
910
GTest_AddTest(circle_resizer_unit_test ${CIRCLE_RESIZER_TEST_SOURCES})
1011
target_link_libraries(circle_resizer_unit_test circle_resizer_core)
12+
13+
get_target_property(ARTIFACTS_PATH testDataGenerator BINARY_DIR)
14+
set_tests_properties(circle_resizer_unit_test
15+
PROPERTIES
16+
ENVIRONMENT "ARTIFACTS_PATH=${ARTIFACTS_PATH}")

0 commit comments

Comments
 (0)