Skip to content

Commit c067e22

Browse files
mempfWovchena
andauthored
Fix yolo detection results for models created with the 2022.1 model optimizer (#3463)
* Fix yolo detection results for models created with the 2022.1 model optimizer * object_detection_demo/cpp: fix yolo-v1-tiny-tf Co-authored-by: Wovchena <[email protected]>
1 parent 87e8585 commit c067e22

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

demos/common/cpp/models/src/detection_model_yolo.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,14 @@ void ModelYolo::prepareInputsOutputs(std::shared_ptr<ov::Model>& model) {
154154
for (auto& out : outputs) {
155155
ppp.output(out.get_any_name()).tensor().set_element_type(ov::element::f32);
156156
if (out.get_shape().size() == 4) {
157-
if (out.get_shape()[ov::layout::height_idx(yoloRegionLayout)] !=
158-
out.get_shape()[ov::layout::width_idx(yoloRegionLayout)] &&
159-
out.get_shape()[ov::layout::height_idx({"NHWC"})] == out.get_shape()[ov::layout::width_idx({"NHWC"})]) {
160-
yoloRegionLayout = {"NHWC"};
157+
if (out.get_shape()[ov::layout::height_idx("NCHW")] != out.get_shape()[ov::layout::width_idx("NCHW")] &&
158+
out.get_shape()[ov::layout::height_idx("NHWC")] == out.get_shape()[ov::layout::width_idx("NHWC")]) {
159+
ppp.output(out.get_any_name()).model().set_layout("NHWC");
160+
// outShapes are saved before ppp.build() thus set yoloRegionLayout as it is in model before ppp.build()
161+
yoloRegionLayout = "NHWC";
161162
}
162-
ppp.output(out.get_any_name()).tensor().set_layout(yoloRegionLayout);
163+
// yolo-v1-tiny-tf out shape is [1, 21125] thus set layout only for 4 dim tensors
164+
ppp.output(out.get_any_name()).tensor().set_layout("NCHW");
163165
}
164166
outputsNames.push_back(out.get_any_name());
165167
outShapes[out.get_any_name()] = out.get_shape();
@@ -223,7 +225,7 @@ void ModelYolo::prepareInputsOutputs(std::shared_ptr<ov::Model>& model) {
223225
for (const auto& name : outputsNames) {
224226
const auto& shape = outShapes[name];
225227
if (shape[ov::layout::channels_idx(yoloRegionLayout)] % num != 0) {
226-
throw std::logic_error(std::string("Output tenosor ") + name + " has wrong 2nd dimension");
228+
throw std::logic_error(std::string("Output tensor ") + name + " has wrong channel dimension");
227229
}
228230
regions.emplace(
229231
name,
@@ -328,8 +330,8 @@ void ModelYolo::parseYOLOOutput(const std::string& output_name,
328330
case YOLO_V4:
329331
case YOLO_V4_TINY:
330332
case YOLOF:
331-
sideH = static_cast<int>(tensor.get_shape()[ov::layout::height_idx(yoloRegionLayout)]);
332-
sideW = static_cast<int>(tensor.get_shape()[ov::layout::width_idx(yoloRegionLayout)]);
333+
sideH = static_cast<int>(tensor.get_shape()[ov::layout::height_idx("NCHW")]);
334+
sideW = static_cast<int>(tensor.get_shape()[ov::layout::width_idx("NCHW")]);
333335
scaleW = resized_im_w;
334336
scaleH = resized_im_h;
335337
break;

0 commit comments

Comments
 (0)