diff --git a/python/src/main.cpp b/python/src/main.cpp index 94e4b134..8bec2970 100644 --- a/python/src/main.cpp +++ b/python/src/main.cpp @@ -163,24 +163,39 @@ PYBIND11_MODULE(kp, m) m, "Tensor", DOC(kp, Tensor)) .def( "data", - [](kp::Tensor& self) { + [](kp::Tensor& self) -> py::array { // Non-owning container exposing the underlying pointer switch (self.dataType()) { case kp::Memory::DataTypes::eFloat: - return py::array( - self.size(), self.data(), py::cast(&self)); + return py::array_t( + {static_cast(self.size())}, // shape + {sizeof(float)}, // strides + self.data(), // ptr + py::cast(&self)); // parent case kp::Memory::DataTypes::eUnsignedInt: - return py::array( - self.size(), self.data(), py::cast(&self)); + return py::array_t( + {static_cast(self.size())}, // shape + {sizeof(uint32_t)}, // strides + self.data(), // ptr + py::cast(&self)); // parent case kp::Memory::DataTypes::eInt: - return py::array( - self.size(), self.data(), py::cast(&self)); + return py::array_t( + {static_cast(self.size())}, // shape + {sizeof(int32_t)}, // strides + self.data(), // ptr + py::cast(&self)); // parent case kp::Memory::DataTypes::eDouble: - return py::array( - self.size(), self.data(), py::cast(&self)); + return py::array_t( + {static_cast(self.size())}, // shape + {sizeof(double)}, // strides + self.data(), // ptr + py::cast(&self)); // parent case kp::Memory::DataTypes::eBool: - return py::array( - self.size(), self.data(), py::cast(&self)); + return py::array_t( + {static_cast(self.size())}, // shape + {sizeof(bool)}, // strides + self.data(), // ptr + py::cast(&self)); // parent default: throw std::runtime_error( "Kompute Python data type not supported"); @@ -200,30 +215,51 @@ PYBIND11_MODULE(kp, m) m, "Image", DOC(kp, Image)) .def( "data", - [](kp::Image& self) { + [](kp::Image& self) -> py::array { // Non-owning container exposing the underlying pointer switch (self.dataType()) { case kp::Memory::DataTypes::eFloat: - return py::array( - self.size(), self.data(), py::cast(&self)); + return py::array_t( + {static_cast(self.size())}, // shape + {sizeof(float)}, // strides + self.data(), // ptr + py::cast(&self)); // parent case kp::Memory::DataTypes::eUnsignedInt: - return py::array( - self.size(), self.data(), py::cast(&self)); + return py::array_t( + {static_cast(self.size())}, // shape + {sizeof(uint32_t)}, // strides + self.data(), // ptr + py::cast(&self)); // parent case kp::Memory::DataTypes::eInt: - return py::array( - self.size(), self.data(), py::cast(&self)); + return py::array_t( + {static_cast(self.size())}, // shape + {sizeof(int32_t)}, // strides + self.data(), // ptr + py::cast(&self)); // parent case kp::Memory::DataTypes::eUnsignedShort: - return py::array( - self.size(), self.data(), py::cast(&self)); + return py::array_t( + {static_cast(self.size())}, // shape + {sizeof(uint16_t)}, // strides + self.data(), // ptr + py::cast(&self)); // parent case kp::Memory::DataTypes::eShort: - return py::array( - self.size(), self.data(), py::cast(&self)); + return py::array_t( + {static_cast(self.size())}, // shape + {sizeof(int16_t)}, // strides + self.data(), // ptr + py::cast(&self)); // parent case kp::Memory::DataTypes::eUnsignedChar: - return py::array( - self.size(), self.data(), py::cast(&self)); + return py::array_t( + {static_cast(self.size())}, // shape + {sizeof(uint8_t)}, // strides + self.data(), // ptr + py::cast(&self)); // parent case kp::Memory::DataTypes::eChar: - return py::array( - self.size(), self.data(), py::cast(&self)); + return py::array_t( + {static_cast(self.size())}, // shape + {sizeof(int8_t)}, // strides + self.data(), // ptr + py::cast(&self)); // parent default: throw std::runtime_error( "Kompute Python data type not supported");