diff --git a/compiler/circle-resizer/include/ModelData.h b/compiler/circle-resizer/include/ModelData.h new file mode 100644 index 00000000000..8e900b0969c --- /dev/null +++ b/compiler/circle-resizer/include/ModelData.h @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CIRCLE_RESIZER_MODEL_DATA_H__ +#define __CIRCLE_RESIZER_MODEL_DATA_H__ + +#include "Shape.h" + +#include +#include +#include + +namespace luci +{ +class Module; +} + +namespace circle_resizer +{ + +/** + * The representation of Circle Model. + * The purpose of the class is to keep the buffer and the module representation of the model + * synchronized. + */ +class ModelData +{ +public: + /** + * @brief Initialize the model with buffer representation. + * + * Exceptions: + * - std::runtime_error if interpretation of provided buffer as a circle model failed. + */ + explicit ModelData(const std::vector &buffer); + + /** + * @brief Initialize the model with buffer representation. + * + * Exceptions: + * - std::runtime_error if reading a model from provided path failed. + */ + explicit ModelData(const std::string &model_path); + + /** + * @brief Dtor of ModelData. Note that explicit declaration is needed to satisfy forward + * declaration + unique_ptr. + */ + ~ModelData(); + + /** + * @brief Notify that the buffer representation of the model has been modified so the module is no + * more valid. + */ + void invalidate_module(); + + /** + * @brief Notify that the module representation of the model has been modified so the buffer is no + * more valid. + */ + void invalidate_buffer(); + + /** + * @brief Get the loaded model as the buffer. + */ + std::vector &buffer(); + + /** + * @brief Get the loaded model as the module. + */ + luci::Module *module(); + + /** + * @brief Get input shapes of the loaded model. + */ + std::vector input_shapes(); + + /** + * @brief Get output shapes of the loaded model. + * + */ + std::vector output_shapes(); + + /** + * @brief Save the loaded model to the stream. + * + * Exceptions: + * - std::runtime_error if saving the model the given stream failed. + */ + void save(std::ostream &stream); + + /** + * @brief Save the loaded model to the location indicated by output_path. + * + * Exceptions: + * - std::runtime_error if saving the model the given path failed. + */ + void save(const std::string &output_path); + +private: + bool _module_invalidated = false, _buffer_invalidated = false; + std::vector _buffer; + std::unique_ptr _module; +}; + +} // namespace circle_resizer + +#endif // __CIRCLE_RESIZER_MODEL_DATA_H__ diff --git a/compiler/circle-resizer/requires.cmake b/compiler/circle-resizer/requires.cmake index 8e48764f58f..7c0ade30d58 100644 --- a/compiler/circle-resizer/requires.cmake +++ b/compiler/circle-resizer/requires.cmake @@ -1,3 +1,5 @@ require("arser") +require("common-artifacts") +require("mio-circle08") require("safemain") require("vconone") diff --git a/compiler/circle-resizer/src/CMakeLists.txt b/compiler/circle-resizer/src/CMakeLists.txt index 03f5a2dba14..532e8b6fdb1 100644 --- a/compiler/circle-resizer/src/CMakeLists.txt +++ b/compiler/circle-resizer/src/CMakeLists.txt @@ -1,9 +1,14 @@ -list(APPEND CIRCLE_RESIZER_CORE_SOURCES Dim.cpp) -list(APPEND CIRCLE_RESIZER_CORE_SOURCES Shape.cpp) -list(APPEND CIRCLE_RESIZER_CORE_SOURCES ShapeParser.cpp) +list(APPEND CIRCLE_RESIZER_SOURCES Dim.cpp) +list(APPEND CIRCLE_RESIZER_SOURCES Shape.cpp) +list(APPEND CIRCLE_RESIZER_SOURCES ShapeParser.cpp) +list(APPEND CIRCLE_RESIZER_SOURCES ModelData.cpp) -add_library(circle_resizer_core STATIC "${CIRCLE_RESIZER_CORE_SOURCES}") +add_library(circle_resizer_core SHARED "${CIRCLE_RESIZER_SOURCES}") target_include_directories(circle_resizer_core PUBLIC ../include) +target_link_libraries(circle_resizer_core PRIVATE mio_circle08) +target_link_libraries(circle_resizer_core PRIVATE luci_export) +target_link_libraries(circle_resizer_core PRIVATE luci_import) + install(TARGETS circle_resizer_core DESTINATION lib) diff --git a/compiler/circle-resizer/src/ModelData.cpp b/compiler/circle-resizer/src/ModelData.cpp new file mode 100644 index 00000000000..ad9c7690415 --- /dev/null +++ b/compiler/circle-resizer/src/ModelData.cpp @@ -0,0 +1,180 @@ +/* + * Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ModelData.h" + +#include + +#include +#include +#include + +#include +#include + +using namespace circle_resizer; + +namespace +{ +std::vector read_model(const std::string &model_path) +{ + std::ifstream file_stream(model_path, std::ios::in | std::ios::binary | std::ifstream::ate); + if (!file_stream.is_open()) + { + throw std::runtime_error("Failed to open file: " + model_path); + } + + std::streamsize size = file_stream.tellg(); + file_stream.seekg(0, std::ios::beg); + + std::vector buffer(size); + if (!file_stream.read(reinterpret_cast(buffer.data()), size)) + { + throw std::runtime_error("Failed to read file: " + model_path); + } + + return buffer; +} + +std::unique_ptr load_module(const std::vector &model_buffer) +{ + flatbuffers::Verifier verifier{model_buffer.data(), model_buffer.size()}; + if (!circle::VerifyModelBuffer(verifier)) + { + throw std::runtime_error("Verification of the model failed"); + } + + const luci::GraphBuilderSource *source_ptr = &luci::GraphBuilderRegistry::get(); + luci::Importer importer(source_ptr); + return importer.importModule(model_buffer.data(), model_buffer.size()); +} + +class BufferModelContract : public luci::CircleExporter::Contract +{ +public: + BufferModelContract(luci::Module *module) + : _module(module), _buffer{std::make_unique>()} + { + } + + luci::Module *module() const override { return _module; } + + bool store(const char *ptr, const size_t size) const override + { + _buffer->resize(size); + std::copy(ptr, ptr + size, _buffer->begin()); + return true; + } + + std::vector get_buffer() { return *_buffer; } + +private: + luci::Module *_module; + std::unique_ptr> _buffer; +}; + +template +std::vector extract_shapes(const std::vector &nodes) +{ + std::vector shapes; + for (const auto &loco_node : nodes) + { + std::vector dims; + const auto circle_node = loco::must_cast(loco_node); + for (uint32_t dim_idx = 0; dim_idx < circle_node->rank(); dim_idx++) + { + if (circle_node->dim(dim_idx).known()) + { + const int32_t dim_val = circle_node->dim(dim_idx).value(); + dims.push_back(Dim{dim_val}); + } + else + { + dims.push_back(Dim{-1}); + } + } + shapes.push_back(Shape{dims}); + } + return shapes; +} + +} // namespace + +ModelData::ModelData(const std::vector &buffer) + : _buffer{buffer}, _module{load_module(buffer)} +{ +} + +ModelData::ModelData(const std::string &model_path) : ModelData(read_model(model_path)) {} + +void ModelData::invalidate_module() { _module_invalidated = true; } + +void ModelData::invalidate_buffer() { _buffer_invalidated = true; } + +std::vector &ModelData::buffer() +{ + if (_buffer_invalidated) + { + luci::CircleExporter exporter; + BufferModelContract contract(module()); + + if (!exporter.invoke(&contract)) + { + throw std::runtime_error("Exporting buffer from the model failed"); + } + _buffer = contract.get_buffer(); + _buffer_invalidated = false; + } + return _buffer; +} + +luci::Module *ModelData::module() +{ + if (_module_invalidated) + { + _module = load_module(_buffer); + _module_invalidated = false; + } + return _module.get(); +} + +void ModelData::save(std::ostream &stream) +{ + auto &buff = buffer(); + stream.write(reinterpret_cast(buff.data()), buff.size()); + if (!stream.good()) + { + throw std::runtime_error("Failed to write to output stream"); + } +} + +void ModelData::save(const std::string &output_path) +{ + std::ofstream out_stream(output_path, std::ios::out | std::ios::binary); + save(out_stream); +} + +std::vector ModelData::input_shapes() +{ + return extract_shapes(loco::input_nodes(module()->graph())); +} + +std::vector ModelData::output_shapes() +{ + return extract_shapes(loco::output_nodes(module()->graph())); +} + +ModelData::~ModelData() = default; diff --git a/compiler/circle-resizer/tests/CMakeLists.txt b/compiler/circle-resizer/tests/CMakeLists.txt index 6f450467537..a8a0ac2d38d 100644 --- a/compiler/circle-resizer/tests/CMakeLists.txt +++ b/compiler/circle-resizer/tests/CMakeLists.txt @@ -4,7 +4,15 @@ endif(NOT ENABLE_TEST) list(APPEND CIRCLE_RESIZER_TEST_SOURCES Shape.test.cpp) list(APPEND CIRCLE_RESIZER_TEST_SOURCES ShapeParser.test.cpp) +list(APPEND CIRCLE_RESIZER_TEST_SOURCES ModelData.test.cpp) nnas_find_package(GTest REQUIRED) GTest_AddTest(circle_resizer_unit_test ${CIRCLE_RESIZER_TEST_SOURCES}) target_link_libraries(circle_resizer_unit_test circle_resizer_core) +target_link_libraries(circle_resizer_unit_test mio_circle08) +target_link_libraries(circle_resizer_unit_test luci_lang) + +get_target_property(ARTIFACTS_PATH testDataGenerator BINARY_DIR) +set_tests_properties(circle_resizer_unit_test + PROPERTIES + ENVIRONMENT "ARTIFACTS_PATH=${ARTIFACTS_PATH}") diff --git a/compiler/circle-resizer/tests/ModelData.test.cpp b/compiler/circle-resizer/tests/ModelData.test.cpp new file mode 100644 index 00000000000..1392ab99af4 --- /dev/null +++ b/compiler/circle-resizer/tests/ModelData.test.cpp @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ModelData.h" + +#include "luci/IR/Module.h" +#include "loco/IR/Graph.h" + +#include +#include + +#include + +#include + +using namespace circle_resizer; +using ::testing::HasSubstr; + +namespace +{ + +bool compare_shapes(const std::vector ¤t, const std::vector &expected) +{ + if (current.size() != expected.size()) + { + return false; + } + for (size_t i = 0; i < current.size(); ++i) + { + if (!(current[i] == expected[i])) + { + return false; + } + } + return true; +} + +std::string extract_subgraph_name(const std::vector &buffer) +{ + auto model = circle::GetModel(buffer.data()); + if (model) + { + auto subgraphs = model->subgraphs(); + if (subgraphs->size() > 0) + { + auto subgraph = subgraphs->Get(0); + if (subgraph->name()->c_str()) + { + return subgraph->name()->c_str(); + } + } + } + return ""; +} + +// change the first subgraph name using buffer as an input +bool change_subgraph_name(std::vector &buffer, const std::string &name) +{ + auto model = circle::GetMutableModel(buffer.data()); + if (!model) + { + return false; + } + auto subgraphs = model->mutable_subgraphs(); + auto subgraph = subgraphs->GetMutableObject(0); + if (subgraph->name()->size() != name.size()) + { + return false; + } + for (size_t i = 0; i < name.size(); ++i) + { + subgraph->mutable_name()->Mutate(i, name[i]); + } + return true; +} + +// change the first subgraph name using loco::Graph as an input +void change_subgraph_name(loco::Graph *graph, const std::string &name) { graph->name(name); } + +} // namespace + +class ModelDataTest : public ::testing::Test +{ +protected: + void SetUp() override + { + char *path = std::getenv("ARTIFACTS_PATH"); + if (path == nullptr) + { + throw std::runtime_error("environmental variable ARTIFACTS_PATH required for circle-resizer " + "tests was not not provided"); + } + _test_models_dir = path; + } + +protected: + std::string _test_models_dir; +}; + +TEST_F(ModelDataTest, proper_input_output_shapes) +{ + ModelData model_data(_test_models_dir + "/Add_000.circle"); + EXPECT_TRUE(compare_shapes(model_data.input_shapes(), + std::vector{Shape{1, 4, 4, 3}, Shape{1, 4, 4, 3}})); + EXPECT_TRUE(compare_shapes(model_data.output_shapes(), std::vector{Shape{1, 4, 4, 3}})); +} + +TEST_F(ModelDataTest, proper_output_stream) +{ + ModelData model_data(_test_models_dir + "/Add_000.circle"); + std::stringstream out_stream; + model_data.save(out_stream); + out_stream.seekg(0, std::ios::end); + EXPECT_TRUE(out_stream.tellg() > 0); +} + +TEST_F(ModelDataTest, invalidate_module) +{ + ModelData model_data(_test_models_dir + "/Add_000.circle"); + const auto module_before_name_change = model_data.module(); + const std::string new_subgraph_name = "abcd"; + ASSERT_TRUE(change_subgraph_name(model_data.buffer(), new_subgraph_name)); + model_data.invalidate_module(); // after buffer representation change the module is outdated + const auto module_after_name_change = model_data.module(); + EXPECT_EQ(module_after_name_change->graph()->name(), + new_subgraph_name); // check if buffer update applied to the module +} + +TEST_F(ModelDataTest, invalidate_buffer) +{ + ModelData model_data(_test_models_dir + "/Add_000.circle"); + const auto buffer_before_name_change = model_data.buffer(); + const std::string new_subgraph_name = "abcd"; + change_subgraph_name(model_data.module()->graph(), new_subgraph_name); + model_data.invalidate_buffer(); // after module representation change the buffer is outdated + const auto buffer_after_name_change = model_data.buffer(); + EXPECT_EQ(extract_subgraph_name(buffer_after_name_change), + new_subgraph_name); // check if module update applied to the buffer +} + +TEST_F(ModelDataTest, model_file_not_exist_NEG) +{ + auto file_name = "/not_existed.circle"; + try + { + ModelData model_data(file_name); + FAIL() << "Expected std::runtime_error"; + } + catch (const std::runtime_error &err) + { + EXPECT_THAT(err.what(), HasSubstr("Failed to open file")); + EXPECT_THAT(err.what(), HasSubstr(file_name)); + } + catch (...) + { + FAIL() << "Expected std::runtime_error, other exception thrown"; + } +} + +TEST_F(ModelDataTest, invalid_model_NEG) +{ + try + { + ModelData(std::vector{1, 2, 3, 4, 5}); + FAIL() << "Expected std::runtime_error"; + } + catch (const std::runtime_error &err) + { + EXPECT_THAT(err.what(), HasSubstr("Verification of the model failed")); + } + catch (...) + { + FAIL() << "Expected std::runtime_error, other exception thrown"; + } +} + +TEST_F(ModelDataTest, incorrect_output_stream_NEG) +{ + auto model_data = std::make_shared(_test_models_dir + "/Add_000.circle"); + std::ofstream out_stream; + try + { + model_data->save(out_stream); + } + catch (const std::runtime_error &err) + { + EXPECT_THAT(err.what(), HasSubstr("Failed to write to output stream")); + } + catch (...) + { + FAIL() << "Expected std::runtime_error, other exception thrown"; + } +}