Skip to content

Commit 5ce76c3

Browse files
committed
match API to python version, classification, keypoint detection
1 parent d706107 commit 5ce76c3

File tree

2 files changed

+36
-11
lines changed

2 files changed

+36
-11
lines changed

src/cpp/py_bindings/py_classification.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ void init_classification(nb::module_& m) {
2020
nb::class_<ClassificationResult::Classification>(m, "Classification")
2121
.def(nb::init<unsigned int, const std::string, float>())
2222
.def_rw("id", &ClassificationResult::Classification::id)
23-
.def_rw("label", &ClassificationResult::Classification::label)
23+
.def_rw("name", &ClassificationResult::Classification::label)
2424
.def_rw("score", &ClassificationResult::Classification::score);
2525

2626
nb::class_<ClassificationResult, ResultBase>(m, "ClassificationResult")
@@ -39,6 +39,18 @@ void init_classification(nb::module_& m) {
3939
r.feature_vector.get_shape().data());
4040
},
4141
nb::rv_policy::reference_internal)
42+
.def_prop_ro(
43+
"raw_scores",
44+
[](ClassificationResult& r) {
45+
if (!r.raw_scores) {
46+
return nb::ndarray<float, nb::numpy, nb::c_contig>();
47+
}
48+
49+
return nb::ndarray<float, nb::numpy, nb::c_contig>(r.raw_scores.data(),
50+
r.raw_scores.get_shape().size(),
51+
r.raw_scores.get_shape().data());
52+
},
53+
nb::rv_policy::reference_internal)
4254
.def_prop_ro(
4355
"saliency_map",
4456
[](ClassificationResult& r) {

src/cpp/py_bindings/py_keypoint_detection.cpp

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,33 @@ void init_keypoint_detection(nb::module_& m) {
5050

5151
return self.inferBatch(input_mats);
5252
})
53-
.def("postprocess",
54-
[](KeypointDetectionModel& self, InferenceResult& infResult) {
55-
return self.postprocess(infResult);
56-
})
5753
.def_prop_ro_static("__model__", [](nb::object) {
5854
return KeypointDetectionModel::ModelType;
5955
});
6056

6157
nb::class_<KeypointDetectionResult, ResultBase>(m, "KeypointDetectionResult")
6258
.def(nb::init<int64_t, std::shared_ptr<MetaData>>(), nb::arg("frameId") = -1, nb::arg("metaData") = nullptr)
63-
.def_ro("poses", &KeypointDetectionResult::poses);
64-
65-
nb::class_<DetectedKeypoints>(m, "DetectedKeypoints")
66-
.def(nb::init<>())
67-
.def_ro("keypoints", &DetectedKeypoints::keypoints)
68-
.def_ro("scores", &DetectedKeypoints::scores);
59+
.def_prop_ro("keypoints", [](const KeypointDetectionResult& result) {
60+
if (!result.poses.empty()) {
61+
std::vector<size_t> shape = {result.poses[0].keypoints.size(), 2};
62+
return nb::ndarray<float, nb::numpy, nb::c_contig>(
63+
const_cast<void*>(static_cast<const void*>(result.poses[0].keypoints.data())),
64+
shape.size(),
65+
shape.data());
66+
}
67+
return nb::ndarray<float, nb::numpy, nb::c_contig>();
68+
},
69+
nb::rv_policy::reference_internal)
70+
.def_prop_ro("scores", [](const KeypointDetectionResult& result) {
71+
if (!result.poses.empty()) {
72+
std::vector<size_t> shape = {result.poses[0].scores.size()};
73+
return nb::ndarray<float, nb::numpy, nb::c_contig>(
74+
const_cast<void*>(static_cast<const void*>(result.poses[0].scores.data())),
75+
shape.size(),
76+
shape.data());
77+
}
78+
return nb::ndarray<float, nb::numpy, nb::c_contig>();
79+
},
80+
nb::rv_policy::reference_internal
81+
);
6982
}

0 commit comments

Comments
 (0)