Skip to content

Commit 794a7bf

Browse files
authored
Fix incorrect creation of python arrays in Tensor.data (#440)
Previously the Python array would just be filled with the first element of the data repeated to fill the array. Fixes a number of the Python tests. Signed-off-by: Robert Quill <robert.quill@imgtec.com>
1 parent 7024bb4 commit 794a7bf

File tree

1 file changed

+62
-26
lines changed

1 file changed

+62
-26
lines changed

python/src/main.cpp

Lines changed: 62 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -163,24 +163,39 @@ PYBIND11_MODULE(kp, m)
163163
m, "Tensor", DOC(kp, Tensor))
164164
.def(
165165
"data",
166-
[](kp::Tensor& self) {
166+
[](kp::Tensor& self) -> py::array {
167167
// Non-owning container exposing the underlying pointer
168168
switch (self.dataType()) {
169169
case kp::Memory::DataTypes::eFloat:
170-
return py::array(
171-
self.size(), self.data<float>(), py::cast(&self));
170+
return py::array_t<float>(
171+
{static_cast<py::ssize_t>(self.size())}, // shape
172+
{sizeof(float)}, // strides
173+
self.data<float>(), // ptr
174+
py::cast(&self)); // parent
172175
case kp::Memory::DataTypes::eUnsignedInt:
173-
return py::array(
174-
self.size(), self.data<uint32_t>(), py::cast(&self));
176+
return py::array_t<uint32_t>(
177+
{static_cast<py::ssize_t>(self.size())}, // shape
178+
{sizeof(uint32_t)}, // strides
179+
self.data<uint32_t>(), // ptr
180+
py::cast(&self)); // parent
175181
case kp::Memory::DataTypes::eInt:
176-
return py::array(
177-
self.size(), self.data<int32_t>(), py::cast(&self));
182+
return py::array_t<int32_t>(
183+
{static_cast<py::ssize_t>(self.size())}, // shape
184+
{sizeof(int32_t)}, // strides
185+
self.data<int32_t>(), // ptr
186+
py::cast(&self)); // parent
178187
case kp::Memory::DataTypes::eDouble:
179-
return py::array(
180-
self.size(), self.data<double>(), py::cast(&self));
188+
return py::array_t<double>(
189+
{static_cast<py::ssize_t>(self.size())}, // shape
190+
{sizeof(double)}, // strides
191+
self.data<double>(), // ptr
192+
py::cast(&self)); // parent
181193
case kp::Memory::DataTypes::eBool:
182-
return py::array(
183-
self.size(), self.data<bool>(), py::cast(&self));
194+
return py::array_t<bool>(
195+
{static_cast<py::ssize_t>(self.size())}, // shape
196+
{sizeof(bool)}, // strides
197+
self.data<bool>(), // ptr
198+
py::cast(&self)); // parent
184199
default:
185200
throw std::runtime_error(
186201
"Kompute Python data type not supported");
@@ -200,30 +215,51 @@ PYBIND11_MODULE(kp, m)
200215
m, "Image", DOC(kp, Image))
201216
.def(
202217
"data",
203-
[](kp::Image& self) {
218+
[](kp::Image& self) -> py::array {
204219
// Non-owning container exposing the underlying pointer
205220
switch (self.dataType()) {
206221
case kp::Memory::DataTypes::eFloat:
207-
return py::array(
208-
self.size(), self.data<float>(), py::cast(&self));
222+
return py::array_t<float>(
223+
{static_cast<py::ssize_t>(self.size())}, // shape
224+
{sizeof(float)}, // strides
225+
self.data<float>(), // ptr
226+
py::cast(&self)); // parent
209227
case kp::Memory::DataTypes::eUnsignedInt:
210-
return py::array(
211-
self.size(), self.data<uint32_t>(), py::cast(&self));
228+
return py::array_t<uint32_t>(
229+
{static_cast<py::ssize_t>(self.size())}, // shape
230+
{sizeof(uint32_t)}, // strides
231+
self.data<uint32_t>(), // ptr
232+
py::cast(&self)); // parent
212233
case kp::Memory::DataTypes::eInt:
213-
return py::array(
214-
self.size(), self.data<int32_t>(), py::cast(&self));
234+
return py::array_t<int32_t>(
235+
{static_cast<py::ssize_t>(self.size())}, // shape
236+
{sizeof(int32_t)}, // strides
237+
self.data<int32_t>(), // ptr
238+
py::cast(&self)); // parent
215239
case kp::Memory::DataTypes::eUnsignedShort:
216-
return py::array(
217-
self.size(), self.data<uint16_t>(), py::cast(&self));
240+
return py::array_t<uint16_t>(
241+
{static_cast<py::ssize_t>(self.size())}, // shape
242+
{sizeof(uint16_t)}, // strides
243+
self.data<uint16_t>(), // ptr
244+
py::cast(&self)); // parent
218245
case kp::Memory::DataTypes::eShort:
219-
return py::array(
220-
self.size(), self.data<int16_t>(), py::cast(&self));
246+
return py::array_t<int16_t>(
247+
{static_cast<py::ssize_t>(self.size())}, // shape
248+
{sizeof(int16_t)}, // strides
249+
self.data<int16_t>(), // ptr
250+
py::cast(&self)); // parent
221251
case kp::Memory::DataTypes::eUnsignedChar:
222-
return py::array(
223-
self.size(), self.data<uint8_t>(), py::cast(&self));
252+
return py::array_t<uint8_t>(
253+
{static_cast<py::ssize_t>(self.size())}, // shape
254+
{sizeof(uint8_t)}, // strides
255+
self.data<uint8_t>(), // ptr
256+
py::cast(&self)); // parent
224257
case kp::Memory::DataTypes::eChar:
225-
return py::array(
226-
self.size(), self.data<int8_t>(), py::cast(&self));
258+
return py::array_t<int8_t>(
259+
{static_cast<py::ssize_t>(self.size())}, // shape
260+
{sizeof(int8_t)}, // strides
261+
self.data<int8_t>(), // ptr
262+
py::cast(&self)); // parent
227263
default:
228264
throw std::runtime_error(
229265
"Kompute Python data type not supported");

0 commit comments

Comments
 (0)