Skip to content

Commit eea5e80

Browse files
committed
implement changes
1 parent de33b27 commit eea5e80

File tree

5 files changed

+48
-9
lines changed

5 files changed

+48
-9
lines changed

bindings/python/src/pipeline/datatype/ImgDetectionsBindings.cpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,19 +118,34 @@ void bind_imgdetections(pybind11::module& m, void* pCallstack) {
118118
[](ImgDetections& det, std::vector<ImgDetection> val) { det.detections = val; },
119119
DOC(dai, ImgDetectionsT, detections),
120120
py::return_value_policy::reference_internal)
121+
.def_property(
122+
"segmentationMaskWidth",
123+
[](ImgDetections& det) { return &det.segmentationMaskWidth; },
124+
[](ImgDetections& det, size_t val) { det.segmentationMaskWidth = val; },
125+
DOC(dai, ImgDetectionsT, segmentationMaskWidth),
126+
py::return_value_policy::reference_internal)
127+
.def_property(
128+
"segmentationMaskHeight",
129+
[](ImgDetections& det) { return &det.segmentationMaskHeight; },
130+
[](ImgDetections& det, size_t val) { det.segmentationMaskHeight = val; },
131+
DOC(dai, ImgDetectionsT, segmentationMaskHeight),
132+
py::return_value_policy::reference_internal)
121133
.def("getTimestamp", &dai::ImgDetectionsT<dai::ImgDetection>::Buffer::getTimestamp, DOC(dai, Buffer, getTimestamp))
122134
.def("getTimestampDevice", &dai::ImgDetectionsT<dai::ImgDetection>::Buffer::getTimestampDevice, DOC(dai, Buffer, getTimestampDevice))
123135
.def("getSequenceNum", &dai::ImgDetectionsT<dai::ImgDetection>::Buffer::getSequenceNum, DOC(dai, Buffer, getSequenceNum))
124136
.def("getTransformation", [](ImgDetections& msg) { return msg.transformation; })
125137
.def("setTransformation", [](ImgDetections& msg, const std::optional<ImgTransformation>& transformation) { msg.transformation = transformation; })
126138
.def("getSegmentationMaskWidth", &ImgDetections::getSegmentationMaskWidth, DOC(dai, ImgDetectionsT, getSegmentationMaskWidth))
127139
.def("getSegmentationMaskHeight", &ImgDetections::getSegmentationMaskHeight, DOC(dai, ImgDetectionsT, getSegmentationMaskHeight))
128-
.def(
129-
"setMask", &ImgDetections::setSegmentationMask, py::arg("mask"), py::arg("width"), py::arg("height"), DOC(dai, ImgDetectionsT, setSegmentationMask))
140+
.def("setSegmentationMask",
141+
py::overload_cast<dai::ImgFrame&>(&ImgDetections::setSegmentationMask),
142+
py::arg("frame"),
143+
DOC(dai, ImgDetectionsT, setSegmentationMask),
144+
py::return_value_policy::reference_internal)
130145
.def("getMaskData", &ImgDetections::getMaskData, DOC(dai, ImgDetectionsT, getMaskData))
131-
.def("getSegmentationMaskAsImgFrame", &ImgDetections::getSegmentationMask, DOC(dai, ImgDetectionsT, getSegmentationMask))
146+
.def("getSegmentationMask", &ImgDetections::getSegmentationMask, DOC(dai, ImgDetectionsT, getSegmentationMask))
132147
#ifdef DEPTHAI_HAVE_OPENCV_SUPPORT
133-
.def("setSegmentationMask", &ImgDetections::setCvSegmentationMask, py::arg("mask"), DOC(dai, ImgDetectionsT, setCvSegmentationMask))
148+
.def("setCvSegmentationMask", &ImgDetections::setCvSegmentationMask, py::arg("mask"), DOC(dai, ImgDetectionsT, setCvSegmentationMask))
134149
.def(
135150
"getCvSegmentationMask",
136151
[](ImgDetections& self) { return self.getCvSegmentationMask(&g_numpyAllocator); },

include/depthai/pipeline/datatype/ImgDetections.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,9 @@ struct ImgDetection {
139139
};
140140

141141
/**
142-
* ImgDetections message. Carries normalized detection results
142+
* ImgDetections message. Carries normalized detections and optional segmentation mask.
143+
* The segmentation mask is stored as a single-channel INT8 2-d array, where the value represents the instance index in the list of detections.
144+
* The value 255 is treated as a background pixel (no instance).
143145
*/
144146
class ImgDetections : public ImgDetectionsT<ImgDetection>, public ProtoSerializable {
145147
public:

include/depthai/pipeline/datatype/ImgDetectionsT.hpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,25 @@ class ImgDetectionsT : public Buffer {
4545
std::size_t getSegmentationMaskHeight() const;
4646

4747
/*
48-
* Sets the segmentation mask from a vector of bytes, along with width and height.
48+
* Sets the segmentation mask from a vector of bytes.
4949
* The size of the vector must be equal to width * height.
5050
*/
5151
void setSegmentationMask(const std::vector<std::uint8_t>& mask, size_t width, size_t height);
5252

53+
/*
54+
* Sets the segmentation mask from an ImgFrame.
55+
* @param frame Frame must be of type GRAY8
56+
*/
57+
void setSegmentationMask(dai::ImgFrame& frame);
58+
5359
/*
5460
* Returns a copy of the segmentation mask data as a vector of bytes. If mask data is not set, returns std::nullopt.
5561
*/
5662
std::optional<std::vector<std::uint8_t>> getMaskData() const;
5763

64+
/*
65+
* Returns the segmentation mask as an ImgFrame. If mask data is not set, returns std::nullopt.
66+
*/
5867
std::optional<dai::ImgFrame> getSegmentationMask() const;
5968

6069
// Optional - OpenCV support
@@ -71,7 +80,7 @@ class ImgDetectionsT : public Buffer {
7180
void setCvSegmentationMask(cv::Mat mask);
7281

7382
/**
74-
* Retrieves data as cv::Mat with specified width and height. If mask data is not set, returns std::nullopt.
83+
* Retrieves mask data as a cv::Mat copy with specified width and height. If mask data is not set, returns std::nullopt.
7584
* @param allocator Allows callers to supply a custom cv::MatAllocator for zero-copy/custom memory management; nullptr uses OpenCV’s default.
7685
*/
7786
std::optional<cv::Mat> getCvSegmentationMask(cv::MatAllocator* allocator = nullptr);

src/pipeline/datatype/ImgDetectionsT.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@ void ImgDetectionsT<DetectionT>::setSegmentationMask(const std::vector<std::uint
3131
this->segmentationMaskHeight = height;
3232
}
3333

34+
template <class DetectionT>
35+
void ImgDetectionsT<DetectionT>::setSegmentationMask(dai::ImgFrame& frame) {
36+
if(frame.getType() != dai::ImgFrame::Type::GRAY8) {
37+
throw std::runtime_error("SegmentationMask: ImgFrame type must be GRAY8");
38+
}
39+
auto dataSpan = frame.getData();
40+
std::vector<std::uint8_t> vecMask(dataSpan.begin(), dataSpan.end());
41+
setData(vecMask);
42+
this->segmentationMaskWidth = frame.getWidth();
43+
this->segmentationMaskHeight = frame.getHeight();
44+
}
45+
3446
template <class DetectionT>
3547
std::optional<std::vector<std::uint8_t>> ImgDetectionsT<DetectionT>::getMaskData() const {
3648
const auto& d = data->getData();

tests/src/onhost_tests/pipeline/datatype/imgdetections_test.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ TEST_CASE("ImgDetections segmentation mask operations", "[ImgDetections][Segment
274274
}
275275

276276
#ifdef DEPTHAI_HAVE_OPENCV_SUPPORT
277-
SECTION("OpenCV segmentation mask view semantics") {
277+
SECTION("OpenCV segmentation mask copy semantics") {
278278
ImgDetections detections;
279279
constexpr int rows = 3;
280280
constexpr int cols = 4;
@@ -303,7 +303,8 @@ TEST_CASE("ImgDetections segmentation mask operations", "[ImgDetections][Segment
303303
auto shallowData = *optShallowData;
304304

305305
REQUIRE_FALSE(shallowData.empty());
306-
REQUIRE(shallowData.front() == shallow.at<uint8_t>(0, 0));
306+
REQUIRE(shallowData.front() == mask.at<uint8_t>(0, 0));
307+
REQUIRE(shallowData.front() != shallow.at<uint8_t>(0, 0));
307308

308309
cv::Mat constantMask(rows, cols, CV_8UC1, cv::Scalar(7));
309310
detections.setCvSegmentationMask(constantMask);

0 commit comments

Comments
 (0)