Skip to content

Commit 667c856

Browse files
committed
fixes
1 parent e38dbc3 commit 667c856

File tree

3 files changed

+10
-6
lines changed

3 files changed

+10
-6
lines changed

demos/common/cpp/models/src/classification_model.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ void ClassificationModel::prepareInputsOutputs(std::shared_ptr<ov::Model>& model
116116
" 2-dimensional or 4-dimensional output");
117117
}
118118

119-
const ov::Layout outputLayout("NC...");
119+
const ov::Layout outputLayout("NCHW");
120120
if (outputShape.size() == 4 && (outputShape[ov::layout::height_idx(outputLayout)] != 1
121121
|| outputShape[ov::layout::width_idx(outputLayout)] != 1)) {
122122
throw std::logic_error("Classification model wrapper supports topologies only"

demos/common/cpp/models/src/super_resolution_model.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ void SuperResolutionModel::prepareInputsOutputs(std::shared_ptr<ov::Model>& mode
4444
if (lrShape.size() != 4) {
4545
throw std::logic_error("Number of dimensions for an input must be 4");
4646
}
47-
// in case of 2 inputs they have same layouts
48-
ov::Layout inputLayout = getInputLayout(model->input());
47+
// in case of 2 inputs they have the same layouts
48+
ov::Layout inputLayout = getInputLayout(model->inputs().front());
4949

5050
auto channelsId = ov::layout::channels_idx(inputLayout);
5151
auto heightId = ov::layout::height_idx(inputLayout);

demos/common/cpp/utils/include/utils/ocv_common.hpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,11 +204,15 @@ static inline void resize2tensor(const cv::Mat& mat, const ov::Tensor& tensor) {
204204

205205
static inline ov::Layout getLayoutFromShape(const ov::Shape& shape) {
206206
if (shape.size() == 2) {
207-
return { "NC" };
207+
return "NC";
208+
}
209+
else if (shape.size() == 3) {
210+
return (shape[0] >= 1 && shape[0] <= 4) ? "CHW" :
211+
"HWC";
208212
}
209213
else if (shape.size() == 4) {
210-
return (shape[1] >= 1 && shape[1] <= 4) ? ov::Layout{"NCHW"} :
211-
ov::Layout{"NHWC"};
214+
return (shape[1] >= 1 && shape[1] <= 4) ? "NCHW" :
215+
"NHWC";
212216
}
213217
else {
214218
throw std::runtime_error("Usupported " + std::to_string(shape.size()) + "D shape");

0 commit comments

Comments
 (0)