Skip to content

Commit 616c271

Browse files
committed
Fix classification serialization
Had to set ov::batch(model, 1);
1 parent 9384335 commit 616c271

File tree

4 files changed

+60
-15
lines changed

4 files changed

+60
-15
lines changed

src/cpp/include/utils/config.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,32 @@ static inline std::tuple<bool, ov::Layout> makeGuesLayoutFrom4DShape(const ov::P
9797
return {false, ov::Layout{}};
9898
}
9999

100+
static inline std::map<std::string, ov::Layout> parseLayoutString(const std::string& layout_string) {
101+
// Parse parameter string like "input0:NCHW,input1:NC" or "NCHW" (applied to all
102+
// inputs)
103+
std::map<std::string, ov::Layout> layouts;
104+
std::string searchStr =
105+
(layout_string.find_last_of(':') == std::string::npos && !layout_string.empty() ? ":" : "") + layout_string;
106+
auto colonPos = searchStr.find_last_of(':');
107+
while (colonPos != std::string::npos) {
108+
auto startPos = searchStr.find_last_of(',');
109+
auto inputName = searchStr.substr(startPos + 1, colonPos - startPos - 1);
110+
auto inputLayout = searchStr.substr(colonPos + 1);
111+
layouts[inputName] = ov::Layout(inputLayout);
112+
searchStr.resize(startPos + 1);
113+
if (searchStr.empty() || searchStr.back() != ',') {
114+
break;
115+
}
116+
searchStr.pop_back();
117+
colonPos = searchStr.find_last_of(':');
118+
}
119+
if (!searchStr.empty()) {
120+
throw std::invalid_argument("Can't parse input layout string: " + layout_string);
121+
}
122+
return layouts;
123+
}
124+
125+
100126
static inline ov::Layout getLayoutFromShape(const ov::PartialShape& shape) {
101127
if (shape.size() == 2) {
102128
return "NC";
@@ -133,4 +159,21 @@ static inline ov::Layout getLayoutFromShape(const ov::PartialShape& shape) {
133159
throw std::runtime_error("Usupported " + std::to_string(shape.size()) + "D shape");
134160
}
135161

162+
static inline ov::Layout getInputLayout(const ov::Output<ov::Node>& input, std::map<std::string, ov::Layout>& inputsLayouts) {
163+
ov::Layout layout = ov::layout::get_layout(input);
164+
if (layout.empty()) {
165+
if (inputsLayouts.empty()) {
166+
layout = getLayoutFromShape(input.get_partial_shape());
167+
std::cout << "Automatically detected layout '" << layout.to_string() << "' for input '"
168+
<< input.get_any_name() << "' will be used." << std::endl;
169+
} else if (inputsLayouts.size() == 1) {
170+
layout = inputsLayouts.begin()->second;
171+
} else {
172+
layout = inputsLayouts[input.get_any_name()];
173+
}
174+
}
175+
176+
return layout;
177+
}
178+
136179
} // namespace utils

src/cpp/src/tasks/classification.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,11 @@ std::vector<size_t> get_non_xai_output_indices(const std::vector<ov::Output<ov::
8585
cv::Size Classification::serialize(std::shared_ptr<ov::Model>& ov_model) {
8686
// --------------------------- Configure input & output -------------------------------------------------
8787
// --------------------------- Prepare input ------------------------------------------------------
88+
auto config = ov_model->has_rt_info("model_info") ? ov_model->get_rt_info<ov::AnyMap>("model_info") : ov::AnyMap{};
89+
std::string layout = "";
90+
layout = utils::get_from_any_maps("layout", config, {}, layout);
91+
auto inputsLayouts = utils::parseLayoutString(layout);
92+
8893
if (ov_model->inputs().size() != 1) {
8994
throw std::logic_error("Classification model wrapper supports topologies with only 1 input");
9095
}
@@ -93,9 +98,7 @@ cv::Size Classification::serialize(std::shared_ptr<ov::Model>& ov_model) {
9398
auto inputName = input.get_any_name();
9499

95100
const ov::Shape& inputShape = input.get_partial_shape().get_max_shape();
96-
const ov::Layout& inputLayout = utils::getLayoutFromShape(inputShape);
97-
98-
auto config = ov_model->has_rt_info("model_info") ? ov_model->get_rt_info<ov::AnyMap>("model_info") : ov::AnyMap{};
101+
const ov::Layout& inputLayout = utils::getInputLayout(input, inputsLayouts);
99102

100103
auto interpolation_mode = cv::INTER_LINEAR;
101104
utils::RESIZE_MODE resize_mode = utils::RESIZE_FILL;
@@ -121,9 +124,6 @@ cv::Size Classification::serialize(std::shared_ptr<ov::Model>& ov_model) {
121124
mean_values,
122125
scale_values);
123126

124-
ov::preprocess::PrePostProcessor ppp = ov::preprocess::PrePostProcessor(ov_model);
125-
ov_model = ppp.build();
126-
127127
// --------------------------- Prepare output -----------------------------------------------------
128128
if (ov_model->outputs().size() > 5) {
129129
throw std::logic_error("Classification model wrapper supports topologies with up to 4 outputs");

src/cpp/src/utils/preprocessing.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ std::shared_ptr<ov::Model> embedProcessing(std::shared_ptr<ov::Model>& model,
2424
const std::vector<float>& scale,
2525
const std::type_info& dtype) {
2626
ov::preprocess::PrePostProcessor ppp(model);
27+
2728
// Change the input type to the 8-bit image
2829
if (dtype == typeid(int)) {
2930
ppp.input(inputName).tensor().set_element_type(ov::element::u8);
@@ -54,7 +55,9 @@ std::shared_ptr<ov::Model> embedProcessing(std::shared_ptr<ov::Model>& model,
5455
ppp.input(inputName).preprocess().scale(scale);
5556
}
5657

57-
return ppp.build();
58+
auto ov_model = ppp.build();
59+
ov::set_batch(ov_model, 1);
60+
return ov_model;
5861
}
5962

6063
ov::preprocess::PostProcessSteps::CustomPostprocessOp createResizeGraph(RESIZE_MODE resizeMode,

tests/cpp/test_accuracy.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,13 @@ TEST_P(ModelParameterizedTest, AccuracyTest) {
103103
} else if (data.type == "MaskRCNNModel") {
104104
GTEST_SKIP();
105105
} else if (data.type == "ClassificationModel") {
106-
GTEST_SKIP();
107-
//auto model = Classification::load(model_path);
108-
//for (auto& test_data : data.test_data) {
109-
// std::string image_path = DATA_DIR + '/' + test_data.image;
110-
// cv::Mat image = cv::imread(image_path);
111-
// auto result = model.infer(image);
112-
// EXPECT_EQ(std::string{result}, test_data.reference[0]);
113-
//}
106+
auto model = Classification::load(model_path);
107+
for (auto& test_data : data.test_data) {
108+
std::string image_path = DATA_DIR + '/' + test_data.image;
109+
cv::Mat image = cv::imread(image_path);
110+
auto result = model.infer(image);
111+
EXPECT_EQ(std::string{result}, test_data.reference[0]);
112+
}
114113
} else {
115114
FAIL() << "No implementation for model type " << data.type;
116115
}

0 commit comments

Comments
 (0)