|
| 1 | +#include "Sampler.h" |
| 2 | +#include "hoomd/HOOMDMath.h" |
| 3 | +#include "dlpack.h" |
| 4 | +#include <stdexcept> |
| 5 | + |
| 6 | +using namespace std; |
| 7 | +namespace py = pybind11; |
| 8 | + |
// Capsule name the DLPack consumer protocol expects for an unconsumed tensor.
const char* const kDLTensorCapsuleName = "dltensor";
// Bit width of HOOMD's Scalar type (float -> 32, double -> 64).
constexpr uint8_t kBits = std::is_same<Scalar, float>::value ? 32 : 64;
| 11 | + |
| 12 | +template <typename> |
| 13 | +constexpr DLDataType dtype(); |
| 14 | +template <> |
| 15 | +constexpr DLDataType dtype<Scalar4>() { return DLDataType {kDLFloat, kBits, 1}; } |
| 16 | +template <> |
| 17 | +constexpr DLDataType dtype<Scalar3>() { return DLDataType {kDLFloat, kBits, 1}; } |
| 18 | +template <> |
| 19 | +constexpr DLDataType dtype<Scalar>() { return DLDataType {kDLFloat, kBits, 1}; } |
| 20 | +template <> |
| 21 | +constexpr DLDataType dtype<int3>() { return DLDataType {kDLInt, 32, 1}; } |
| 22 | +template <> |
| 23 | +constexpr DLDataType dtype<unsigned int>() { return DLDataType {kDLUInt, 32, 1}; } |
| 24 | +template <> |
| 25 | +constexpr DLDataType dtype<int>() { return DLDataType {kDLInt, 32, 1}; } |
| 26 | + |
| 27 | +template <typename> |
| 28 | +constexpr int64_t stride1(); |
| 29 | +template <> |
| 30 | +constexpr int64_t stride1<Scalar4>() { return 4; } |
| 31 | +template <> |
| 32 | +constexpr int64_t stride1<Scalar3>() { return 3; } |
| 33 | +template <> |
| 34 | +constexpr int64_t stride1<Scalar>() { return 1; } |
| 35 | +template <> |
| 36 | +constexpr int64_t stride1<int3>() { return 3; } |
| 37 | +template <> |
| 38 | +constexpr int64_t stride1<unsigned int>() { return 1; } |
| 39 | + |
// Type-erase a data pointer for storage in DLTensor::data.
template <typename T>
inline void* opaque(T* data)
{
    return static_cast<void*>(data);
}
| 42 | + |
// Wrap a DLManagedTensor pointer in a python capsule named "dltensor", as the
// DLPack consumer protocol requires.
// NOTE(review): no capsule destructor is installed, so the consumer must not
// rely on the capsule to free the tensor; the bridge owning it lives on the
// caller's stack — confirm consumers only use the tensor during the callback.
inline py::capsule encapsulate(DLManagedTensor* dl_managed_tensor)
{
    return py::capsule(dl_managed_tensor, kDLTensorCapsuleName);
}
| 47 | + |
// Construct a Sampler that invokes `python_update` from the half-step hook.
//
// sysdef        — system whose particle data will be exposed to python
// python_update — python callable receiving the DLPack capsules built in
//                 run_on_data
Sampler::Sampler(shared_ptr<SystemDefinition> sysdef,
                 py::function python_update)
    :
    HalfStepHook(),
    m_python_update(python_update)
{
    this->setSystemDefinition(sysdef);
}
| 56 | + |
// Bind this sampler to a system definition, caching the particle data and
// execution configuration it needs on every update.
void Sampler::setSystemDefinition(shared_ptr<SystemDefinition> sysdef)
{
    m_sysdef = sysdef;
    m_pdata = sysdef->getParticleData();   // must precede the line below
    m_exec_conf = m_pdata->getExecConf();
}
| 63 | + |
| 64 | +void Sampler::run_on_data(py::function py_exec, const access_location::Enum location, const access_mode::Enum mode) |
| 65 | +{ |
| 66 | + if(location == access_location::device and not m_exec_conf->isCUDAEnabled()) |
| 67 | + throw runtime_error("Invalid request for device memory in non-cuda run."); |
| 68 | + |
| 69 | + const bool on_device = location == access_location::device; |
| 70 | + |
| 71 | + const ArrayHandle<Scalar4> pos(m_pdata->getPositions(), location, mode); |
| 72 | + auto pos_bridge = wrap<Scalar4, Scalar>(pos.data, on_device, 4 ); |
| 73 | + auto pos_capsule = encapsulate(&pos_bridge.tensor); |
| 74 | + |
| 75 | + const ArrayHandle<Scalar4> vel(m_pdata->getVelocities(), location, mode); |
| 76 | + auto vel_bridge = wrap<Scalar4, Scalar>(vel.data, on_device, 4 ); |
| 77 | + auto vel_capsule = encapsulate(&vel_bridge.tensor); |
| 78 | + |
| 79 | + const ArrayHandle<unsigned int> rtags(m_pdata->getRTags(), location, mode); |
| 80 | + auto rtags_bridge = wrap<unsigned int, unsigned int>(rtags.data, on_device, 1); |
| 81 | + auto rtags_capsule = encapsulate(&rtags_bridge.tensor); |
| 82 | + |
| 83 | + const ArrayHandle<int3> img(m_pdata->getImages(), location, mode); |
| 84 | + auto img_bridge = wrap<int3, int>(img.data, on_device, 3); |
| 85 | + auto img_capsule = encapsulate(&img_bridge.tensor); |
| 86 | + |
| 87 | + ArrayHandle<Scalar4> force(m_pdata->getNetForce(), location, access_mode::readwrite); |
| 88 | + auto force_bridge = wrap<Scalar4, Scalar>(force.data, on_device, 4 ); |
| 89 | + auto force_capsule = encapsulate(&force_bridge.tensor); |
| 90 | + |
| 91 | + py_exec(pos_capsule, vel_capsule, rtags_capsule, img_capsule, force_capsule); |
| 92 | +} |
| 93 | + |
| 94 | +void Sampler::update(unsigned int timestep) |
| 95 | +{ |
| 96 | + |
| 97 | + // Accessing the handles here holds them valid until the block of this function. |
| 98 | + // This keeps them valid for the python function call |
| 99 | + auto location = m_exec_conf->isCUDAEnabled() ? access_location::device : access_location::host; |
| 100 | + |
| 101 | + // const ArrayHandle<Scalar4> pos(m_pdata->getPositions(), location, access_mode::read); |
| 102 | + // auto pos_tensor = wrap<Scalar4, Scalar>(pos.data, 4 ); |
| 103 | + // ArrayHandle<Scalar4> vel(m_pdata->getVelocities(), location, access_mode::read); |
| 104 | + // auto vel_tensor = wrap<Scalar4, Scalar>(vel.data, 4); |
| 105 | + // ArrayHandle<unsigned int> rtags(m_pdata->getRTags(), location, access_mode::read); |
| 106 | + // auto rtag_tensor = wrap<unsigned int, unsigned int>(rtags.data, 1); |
| 107 | + // ArrayHandle<int3> img(m_pdata->getImages(), location, access_mode::read); |
| 108 | + // auto img_tensor = wrap<int3, int>(img.data, 3); |
| 109 | + |
| 110 | + // ArrayHandle<Scalar4> net_forces(m_pdata->getNetForce(), location, access_mode::readwrite); |
| 111 | + // auto force_tensor = wrap<Scalar4, Scalar>(net_forces.data, 4); |
| 112 | + |
| 113 | + // m_python_update(pos_tensor, vel_tensor, rtag_tensor, img_tensor, force_tensor, |
| 114 | + // m_pdata->getGlobalBox()); |
| 115 | + this->run_on_data(m_python_update, location, access_mode::read); |
| 116 | +} |
| 117 | + |
| 118 | +template <typename TV, typename TS> |
| 119 | +DLDataBridge Sampler::wrap(TV* ptr, |
| 120 | + const bool on_device, |
| 121 | + const int64_t size2, |
| 122 | + const uint64_t offset, |
| 123 | + uint64_t stride1_offset) { |
| 124 | + assert((size2 >= 1)); // assert is a macro so the extra parentheses are requiered here |
| 125 | + |
| 126 | + const unsigned int particle_number = this->m_pdata->getN(); |
| 127 | + const int gpu_id = on_device ? m_exec_conf->getGPUIds()[0] : m_exec_conf->getRank(); |
| 128 | + |
| 129 | + DLDataBridge bridge; |
| 130 | + bridge.tensor.manager_ctx = NULL; |
| 131 | + bridge.tensor.deleter = NULL; |
| 132 | + |
| 133 | + bridge.tensor.dl_tensor.data = opaque(ptr); |
| 134 | + bridge.tensor.dl_tensor.ctx = DLContext{on_device ? kDLGPU : kDLCPU, gpu_id}; |
| 135 | + bridge.tensor.dl_tensor.dtype = dtype<TS>(); |
| 136 | + |
| 137 | + bridge.shape.push_back(particle_number); |
| 138 | + if (size2 > 1) |
| 139 | + bridge.shape.push_back(size2); |
| 140 | + |
| 141 | + bridge.strides.push_back(stride1<TV>() + stride1_offset); |
| 142 | + if (size2 > 1) |
| 143 | + bridge.strides.push_back(1); |
| 144 | + |
| 145 | + bridge.tensor.dl_tensor.ndim = bridge.shape.size(); |
| 146 | + bridge.tensor.dl_tensor.dtype = dtype<TS>(); |
| 147 | + bridge.tensor.dl_tensor.shape = reinterpret_cast<std::int64_t*>(bridge.shape.data()); |
| 148 | + bridge.tensor.dl_tensor.strides = reinterpret_cast<std::int64_t*>(bridge.strides.data()); |
| 149 | + bridge.tensor.dl_tensor.byte_offset = offset; |
| 150 | + |
| 151 | + return bridge; |
| 152 | +} |
| 153 | + |
| 154 | + |
// Expose Sampler to python as "DLextSampler", deriving from HalfStepHook so
// the simulation loop can call update() each half step.
// NOTE(review): py::base<> is deprecated (removed in pybind11 >= 2.2); newer
// pybind11 lists the base as a template argument of py::class_. Left as-is
// pending a check of the project's pinned pybind11 version.
void export_Sampler(py::module& m)
{
    py::class_<Sampler, std::shared_ptr<Sampler> >(m, "DLextSampler", py::base<HalfStepHook>())
        .def(py::init<std::shared_ptr<SystemDefinition>, py::function>())
        .def("run_on_data", &Sampler::run_on_data)
        ;
}
0 commit comments