Skip to content

Commit 533563e

Browse files
authored
set output precisions (#3286)
1 parent 7bc60b4 commit 533563e

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

demos/common/cpp/models/src/classification_model.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ ClassificationModel::ClassificationModel(const std::string& modelFileName, size_
3232
labels(labels) {}
3333

3434
std::unique_ptr<ResultBase> ClassificationModel::postprocess(InferenceResult& infResult) {
35-
const ov::Tensor& scoresTensor = infResult.outputsData.find(outputsNames[0])->second;
36-
const float* scoresPtr = scoresTensor.data<float>();
37-
const ov::Tensor& indicesTensor = infResult.outputsData.find(outputsNames[1])->second;
35+
const ov::Tensor& indicesTensor = infResult.outputsData.find(outputsNames[0])->second;
3836
const int* indicesPtr = indicesTensor.data<int>();
37+
const ov::Tensor& scoresTensor = infResult.outputsData.find(outputsNames[1])->second;
38+
const float* scoresPtr = scoresTensor.data<float>();
3939

4040
ClassificationResult* result = new ClassificationResult(infResult.frameId, infResult.metaData);
4141
auto retVal = std::unique_ptr<ResultBase>(result);
@@ -165,13 +165,20 @@ void ClassificationModel::prepareInputsOutputs(std::shared_ptr<ov::Model>& model
165165
ov::op::v3::TopK::Mode::MAX,
166166
ov::op::v3::TopK::SortType::SORT_VALUES);
167167

168-
auto scores = std::make_shared<ov::op::v0::Result>(topkNode->output(0));
169-
auto indices = std::make_shared<ov::op::v0::Result>(topkNode->output(1));
168+
auto indices = std::make_shared<ov::op::v0::Result>(topkNode->output(0));
169+
auto scores = std::make_shared<ov::op::v0::Result>(topkNode->output(1));
170170
ov::ResultVector res({ scores, indices });
171171
model = std::make_shared<ov::Model>(res, model->get_parameters(), "classification");
172+
172173
// manually set output tensors name for created topK node
173-
model->outputs()[0].set_names({"indices"});
174+
model->outputs()[0].set_names({ "indices" });
174175
outputsNames.push_back("indices");
175-
model->outputs()[1].set_names({"scores"});
176+
model->outputs()[1].set_names({ "scores" });
176177
outputsNames.push_back("scores");
178+
179+
// set output precisions
180+
ppp = ov::preprocess::PrePostProcessor(model);
181+
ppp.output("indices").tensor().set_element_type(ov::element::i32);
182+
ppp.output("scores").tensor().set_element_type(ov::element::f32);
183+
model = ppp.build();
177184
}

0 commit comments

Comments
 (0)