Skip to content

Commit 3d0b181

Browse files
authored
Reimplement the data access via HalfStepHook (#7)
1 parent dcf77d4 commit 3d0b181

File tree

6 files changed

+211
-0
lines changed

6 files changed

+211
-0
lines changed

Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ FROM ssages/pysages-base:latest
22

33
COPY . hoomd-dlext
44
RUN cd hoomd-dlext && mkdir build && cd build && cmake .. && make install
5+
RUN python3 -c "import hoomd; import hoomd.dlext"

dlext/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ set(
44
${COMPONENT_NAME}_sources
55
SystemView.cc
66
PyDLExt.cc
7+
Sampler.cc
78
)
89

910
pybind11_add_module(${COMPONENT_NAME} SHARED ${${COMPONENT_NAME}_sources} NO_EXTRAS)

dlext/PyDLExt.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "PyDLExt.h"
55
#include "PyHalfStepHook.h"
6+
#include "Sampler.h"
67

78

89
using namespace sysview;
@@ -54,6 +55,7 @@ PYBIND11_MODULE(dlpack_extension, m)
5455
// Classes
5556
export_SystemView(m);
5657
export_PyHalfStepHook(m);
58+
export_Sampler(m);
5759

5860
// Methods
5961
m.def("positions_types", encapsulate<&positions_types>);

dlext/Sampler.cc

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
#include "Sampler.h"
2+
#include "hoomd/HOOMDMath.h"
3+
#include "dlpack.h"
4+
#include <stdexcept>
5+
6+
using namespace std;
7+
namespace py = pybind11;
8+
9+
const char* const kDLTensorCapsuleName = "dltensor";
10+
constexpr uint8_t kBits = std::is_same<Scalar, float>::value ? 32 : 64;
11+
12+
template <typename>
13+
constexpr DLDataType dtype();
14+
template <>
15+
constexpr DLDataType dtype<Scalar4>() { return DLDataType {kDLFloat, kBits, 1}; }
16+
template <>
17+
constexpr DLDataType dtype<Scalar3>() { return DLDataType {kDLFloat, kBits, 1}; }
18+
template <>
19+
constexpr DLDataType dtype<Scalar>() { return DLDataType {kDLFloat, kBits, 1}; }
20+
template <>
21+
constexpr DLDataType dtype<int3>() { return DLDataType {kDLInt, 32, 1}; }
22+
template <>
23+
constexpr DLDataType dtype<unsigned int>() { return DLDataType {kDLUInt, 32, 1}; }
24+
template <>
25+
constexpr DLDataType dtype<int>() { return DLDataType {kDLInt, 32, 1}; }
26+
27+
template <typename>
28+
constexpr int64_t stride1();
29+
template <>
30+
constexpr int64_t stride1<Scalar4>() { return 4; }
31+
template <>
32+
constexpr int64_t stride1<Scalar3>() { return 3; }
33+
template <>
34+
constexpr int64_t stride1<Scalar>() { return 1; }
35+
template <>
36+
constexpr int64_t stride1<int3>() { return 3; }
37+
template <>
38+
constexpr int64_t stride1<unsigned int>() { return 1; }
39+
40+
//! Erase the pointee type of `data`, returning the same address as a `void*`.
template <typename T>
inline void* opaque(T* data)
{
    void* const erased = data;
    return erased;
}
42+
43+
//! Wrap a managed tensor pointer in a Python capsule named kDLTensorCapsuleName.
/*! The capsule is built from the pointer and the name only; no destructor
    callback is passed to it here.
*/
inline py::capsule encapsulate(DLManagedTensor* dl_managed_tensor)
{
    py::capsule tensor_capsule(dl_managed_tensor, kDLTensorCapsuleName);
    return tensor_capsule;
}
47+
48+
//! Construct a sampler bound to a system definition and a Python callback.
/*! \param sysdef System definition handed to setSystemDefinition()
    \param python_update Python callable stored for later use by update()
*/
Sampler::Sampler(shared_ptr<SystemDefinition> sysdef, py::function python_update)
    : HalfStepHook(), m_python_update(python_update)
{
    setSystemDefinition(sysdef);
}
56+
57+
//! Remember the system definition and cache the objects derived from it.
/*! \param sysdef System definition to attach to; it is dereferenced without a null check

    Stores `sysdef`, its particle data, and the execution configuration
    reported by that particle data.
*/
void Sampler::setSystemDefinition(shared_ptr<SystemDefinition> sysdef)
{
    m_sysdef = sysdef;
    m_pdata = m_sysdef->getParticleData();
    m_exec_conf = m_pdata->getExecConf();
}
63+
64+
//! Call a Python function with capsules describing the current particle data.
/*! \param py_exec Python callable; receives five capsules in this order:
                   positions, velocities, reverse tags, images, net forces
    \param location Memory space in which the arrays are acquired
    \param mode Access mode for every array except the net forces, which are
                always acquired with access_mode::readwrite
    \throws std::runtime_error if device memory is requested while CUDA is not enabled

    The array handles and the bridges the capsules point to are local to this
    call, so `py_exec` must not keep the capsules beyond its own return.
*/
void Sampler::run_on_data(py::function py_exec, const access_location::Enum location, const access_mode::Enum mode)
{
    const bool on_device = (location == access_location::device);
    if (on_device && !m_exec_conf->isCUDAEnabled())
        throw runtime_error("Invalid request for device memory in non-cuda run.");

    // Positions
    const ArrayHandle<Scalar4> positions(m_pdata->getPositions(), location, mode);
    auto positions_bridge = wrap<Scalar4, Scalar>(positions.data, on_device, 4);
    auto positions_capsule = encapsulate(&positions_bridge.tensor);

    // Velocities
    const ArrayHandle<Scalar4> velocities(m_pdata->getVelocities(), location, mode);
    auto velocities_bridge = wrap<Scalar4, Scalar>(velocities.data, on_device, 4);
    auto velocities_capsule = encapsulate(&velocities_bridge.tensor);

    // Reverse tags
    const ArrayHandle<unsigned int> reverse_tags(m_pdata->getRTags(), location, mode);
    auto reverse_tags_bridge = wrap<unsigned int, unsigned int>(reverse_tags.data, on_device, 1);
    auto reverse_tags_capsule = encapsulate(&reverse_tags_bridge.tensor);

    // Images
    const ArrayHandle<int3> images(m_pdata->getImages(), location, mode);
    auto images_bridge = wrap<int3, int>(images.data, on_device, 3);
    auto images_capsule = encapsulate(&images_bridge.tensor);

    // Net forces: always read-write, whatever `mode` says.
    ArrayHandle<Scalar4> net_forces(m_pdata->getNetForce(), location, access_mode::readwrite);
    auto net_forces_bridge = wrap<Scalar4, Scalar>(net_forces.data, on_device, 4);
    auto net_forces_capsule = encapsulate(&net_forces_bridge.tensor);

    py_exec(positions_capsule,
            velocities_capsule,
            reverse_tags_capsule,
            images_capsule,
            net_forces_capsule);
}
93+
94+
void Sampler::update(unsigned int timestep)
95+
{
96+
97+
// Accessing the handles here holds them valid until the block of this function.
98+
// This keeps them valid for the python function call
99+
auto location = m_exec_conf->isCUDAEnabled() ? access_location::device : access_location::host;
100+
101+
// const ArrayHandle<Scalar4> pos(m_pdata->getPositions(), location, access_mode::read);
102+
// auto pos_tensor = wrap<Scalar4, Scalar>(pos.data, 4 );
103+
// ArrayHandle<Scalar4> vel(m_pdata->getVelocities(), location, access_mode::read);
104+
// auto vel_tensor = wrap<Scalar4, Scalar>(vel.data, 4);
105+
// ArrayHandle<unsigned int> rtags(m_pdata->getRTags(), location, access_mode::read);
106+
// auto rtag_tensor = wrap<unsigned int, unsigned int>(rtags.data, 1);
107+
// ArrayHandle<int3> img(m_pdata->getImages(), location, access_mode::read);
108+
// auto img_tensor = wrap<int3, int>(img.data, 3);
109+
110+
// ArrayHandle<Scalar4> net_forces(m_pdata->getNetForce(), location, access_mode::readwrite);
111+
// auto force_tensor = wrap<Scalar4, Scalar>(net_forces.data, 4);
112+
113+
// m_python_update(pos_tensor, vel_tensor, rtag_tensor, img_tensor, force_tensor,
114+
// m_pdata->getGlobalBox());
115+
this->run_on_data(m_python_update, location, access_mode::read);
116+
}
117+
118+
template <typename TV, typename TS>
119+
DLDataBridge Sampler::wrap(TV* ptr,
120+
const bool on_device,
121+
const int64_t size2,
122+
const uint64_t offset,
123+
uint64_t stride1_offset) {
124+
assert((size2 >= 1)); // assert is a macro so the extra parentheses are requiered here
125+
126+
const unsigned int particle_number = this->m_pdata->getN();
127+
const int gpu_id = on_device ? m_exec_conf->getGPUIds()[0] : m_exec_conf->getRank();
128+
129+
DLDataBridge bridge;
130+
bridge.tensor.manager_ctx = NULL;
131+
bridge.tensor.deleter = NULL;
132+
133+
bridge.tensor.dl_tensor.data = opaque(ptr);
134+
bridge.tensor.dl_tensor.ctx = DLContext{on_device ? kDLGPU : kDLCPU, gpu_id};
135+
bridge.tensor.dl_tensor.dtype = dtype<TS>();
136+
137+
bridge.shape.push_back(particle_number);
138+
if (size2 > 1)
139+
bridge.shape.push_back(size2);
140+
141+
bridge.strides.push_back(stride1<TV>() + stride1_offset);
142+
if (size2 > 1)
143+
bridge.strides.push_back(1);
144+
145+
bridge.tensor.dl_tensor.ndim = bridge.shape.size();
146+
bridge.tensor.dl_tensor.dtype = dtype<TS>();
147+
bridge.tensor.dl_tensor.shape = reinterpret_cast<std::int64_t*>(bridge.shape.data());
148+
bridge.tensor.dl_tensor.strides = reinterpret_cast<std::int64_t*>(bridge.strides.data());
149+
bridge.tensor.dl_tensor.byte_offset = offset;
150+
151+
return bridge;
152+
}
153+
154+
155+
//! Register Sampler with a pybind11 module under the Python name "DLextSampler".
/*! \param m Module that receives the class

    The class is declared as derived from HalfStepHook, is held by
    std::shared_ptr, and exposes its constructor and run_on_data().
    The base class is given as a py::class_ template argument, which replaces
    the deprecated py::base<>() constructor argument.
*/
void export_Sampler(py::module& m)
{
    py::class_<Sampler, HalfStepHook, std::shared_ptr<Sampler>>(m, "DLextSampler")
        .def(py::init<std::shared_ptr<SystemDefinition>, py::function>())
        .def("run_on_data", &Sampler::run_on_data);
}

dlext/Sampler.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// SPDX-License-Identifier: MIT
2+
// This file is part of `hoomd-dlext`, see LICENSE.md
3+
4+
#ifndef SAMPLER_H
5+
#define SAMPLER_H
6+
7+
#include "hoomd/HalfStepHook.h"
8+
#include "hoomd/GlobalArray.h"
9+
#include <hoomd/extern/pybind/include/pybind11/pybind11.h>
10+
#include "dlpack.h"
11+
12+
struct DLDataBridge {
13+
std::vector<int64_t> shape;
14+
std::vector<int64_t> strides;
15+
DLManagedTensor tensor;
16+
};
17+
18+
19+
class Sampler : public HalfStepHook
20+
{
21+
public:
22+
//! Constructor
23+
Sampler(std::shared_ptr<SystemDefinition> sysdef, pybind11::function python_update);
24+
25+
virtual void setSystemDefinition(std::shared_ptr<SystemDefinition> sysdef) override;
26+
27+
//! Take one timestep forward
28+
virtual void update(unsigned int timestep) override;
29+
30+
// run a custom python function on data from hoomd
31+
// access_mode is ignored for forces. Forces are returned in readwrite mode always.
32+
void run_on_data(pybind11::function py_exec, const access_location::Enum location, const access_mode::Enum mode);
33+
34+
private:
35+
template<typename TS, typename TV>
36+
DLDataBridge wrap(TS* const ptr, const bool, const int64_t size2 = 1, const uint64_t offset=0, uint64_t stride1_offset = 0);
37+
pybind11::function m_python_update;
38+
std::shared_ptr<SystemDefinition> m_sysdef;
39+
std::shared_ptr<ParticleData> m_pdata;
40+
std::shared_ptr<const ExecutionConfiguration> m_exec_conf;
41+
};
42+
43+
//! Register the Sampler class with the given pybind11 module (exposed to Python as "DLextSampler").
void export_Sampler(pybind11::module& m);
44+
45+
#endif//SAMPLER_H

dlext/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,5 @@
2020
tags,
2121
rtags,
2222
velocities_masses,
23+
DLextSampler,
2324
)

0 commit comments

Comments
 (0)