1616#include < limits> // std::numeric_limits<T>::signaling_NaN
1717#include < sstream>
1818
19+ #include " pluto/pluto.h"
20+
1921#include " atlas/array/ArrayDataStore.h"
2022#include " atlas/library/Library.h"
2123#include " atlas/library/config.h"
2426#include " atlas/runtime/Log.h"
2527#include " eckit/log/Bytes.h"
2628
27- #include " hic/hic.h"
28-
29-
// Compile-time debug switch, defined as 0 here. It is tested with
// `if constexpr (ATLAS_ACC_DEBUG)` inside WrappedDataStore::accUnmap; the statements
// it guards are not visible in this view (presumably extra diagnostics -- TODO confirm).
3029#define ATLAS_ACC_DEBUG 0
3130
3231// ------------------------------------------------------------------------------
@@ -94,26 +93,15 @@ template <typename Value>
// No-op variant of initialise: accepts an array and a size and does nothing.
// It sits in a preprocessor branch closed by the `#endif` on the next line; the branch
// condition and any alternative definition lie outside this view.
9493void initialise (Value[], size_t ) {}
9594#endif
9695
// Removed by this change (every line carries the `-` marker); the call sites shown
// further down now call pluto::devices() instead.
// The removed helper returned a device count that was queried once through
// hicGetDeviceCount and cached in a function-local static. When the query did not
// return hicSuccess the count was forced to 0.
97- static int devices () {
98- static int devices_ = [](){
99- int n = 0 ;
100- auto err = hicGetDeviceCount (&n);
101- if (err != hicSuccess) {
102- n = 0 ;
// Result deliberately discarded; presumably called to clear the pending error -- TODO confirm.
103- static_cast <void >(hicGetLastError ());
104- }
105- return n;
106- }();
107- return devices_;
108- }
109-
11096template <typename Value>
11197class DataStore : public ArrayDataStore {
11298public:
113- DataStore (size_t size): size_(size) {
99+ DataStore (size_t size): size_(size),
100+ host_allocator_{pluto::host_resource ()},
101+ device_allocator_{pluto::device_resource ()} {
114102 allocateHost ();
115103 initialise (host_data_, size_);
116- if (ATLAS_HAVE_GPU && devices ()) {
104+ if (ATLAS_HAVE_GPU && pluto:: devices ()) {
117105 device_updated_ = false ;
118106 }
119107 else {
@@ -127,25 +115,19 @@ class DataStore : public ArrayDataStore {
127115 }
128116
// Copies the host buffer to the device buffer and sets device_updated_ to true.
// Acts only when ATLAS_HAVE_GPU is set and pluto::devices() evaluates true; device
// storage is allocated first via allocateDevice() when device_allocated_ is false.
// Change in this diff: the hicMemcpy call with its explicit hicSuccess check and
// throw_AssertionFailed is replaced by pluto::copy_host_to_device. The new call is
// passed size_ where the old one passed size_*sizeof(Value), so it presumably takes
// an element count -- TODO confirm.
// NOTE(review): failure reporting now depends on pluto; confirm it raises on error.
129117 void updateDevice () const override {
130- if (ATLAS_HAVE_GPU && devices ()) {
118+ if (ATLAS_HAVE_GPU && pluto:: devices ()) {
131119 if (not device_allocated_) {
132120 allocateDevice ();
133121 }
134- hicError_t err = hicMemcpy (device_data_, host_data_, size_*sizeof (Value), hicMemcpyHostToDevice);
135- if (err != hicSuccess) {
136- throw_AssertionFailed (" Failed to updateDevice: " +std::string (hicGetErrorString (err)), Here ());
137- }
122+ pluto::copy_host_to_device (device_data_, host_data_, size_);
138123 device_updated_ = true ;
139124 }
140125 }
141126
142127 void updateHost () const override {
143128 if constexpr (ATLAS_HAVE_GPU) {
144129 if (device_allocated_) {
145- hicError_t err = hicMemcpy (host_data_, device_data_, size_*sizeof (Value), hicMemcpyDeviceToHost);
146- if (err != hicSuccess) {
147- throw_AssertionFailed (" Failed to updateHost: " +std::string (hicGetErrorString (err)), Here ());
148- }
130+ pluto::copy_device_to_host (host_data_, device_data_, size_);
149131 host_updated_ = true ;
150132 }
151133 }
@@ -174,32 +156,24 @@ class DataStore : public ArrayDataStore {
// Returns the device_allocated_ flag, which allocateDevice() sets and deallocateDevice() clears.
174156 bool deviceAllocated () const override { return device_allocated_; }
175157
// Allocates device storage for size_ elements.
// Does nothing unless ATLAS_HAVE_GPU is set and pluto::devices() evaluates true, and
// returns early when device_allocated_ is already true. When size_ is 0 nothing is
// allocated and device_allocated_ stays false.
// After a successful allocation device_allocated_ is set and accMap() is called.
// Change in this diff: hicMalloc plus its hicSuccess check/throw_AssertionFailed is
// replaced by device_allocator_.allocate(size_).
176158 void allocateDevice () const override {
177- if (ATLAS_HAVE_GPU && devices ()) {
159+ if (ATLAS_HAVE_GPU && pluto:: devices ()) {
178160 if (device_allocated_) {
179161 return ;
180162 }
181163 if (size_) {
182- hicError_t err = hicMalloc ((void **)&device_data_, sizeof (Value)*size_);
183- if (err != hicSuccess) {
184- throw_AssertionFailed (" Failed to allocate GPU memory: " + std::string (hicGetErrorString (err)), Here ());
185- }
164+ device_data_ = device_allocator_.allocate (size_);
186165 device_allocated_ = true ;
187166 accMap ();
188167 }
189168 }
190169 }
191170
// Releases device storage when device_allocated_ is true: accUnmap() is called first,
// then the buffer is handed back with device_allocator_.deallocate(device_data_, size_),
// and finally device_data_ and device_allocated_ are reset.
// Change in this diff: the `if constexpr (ATLAS_HAVE_GPU)` wrapper and the hicFree
// error check are removed. The device_allocated_ guard remains; that flag is only set
// inside allocateDevice() under its GPU condition.
192171 void deallocateDevice () const override {
193- if constexpr (ATLAS_HAVE_GPU) {
194- if (device_allocated_) {
195- accUnmap ();
196- hicError_t err = hicFree (device_data_);
197- if (err != hicSuccess) {
198- throw_AssertionFailed (" Failed to deallocate GPU memory: " + std::string (hicGetErrorString (err)), Here ());
199- }
200- device_data_ = nullptr ;
201- device_allocated_ = false ;
202- }
172+ if (device_allocated_) {
173+ accUnmap ();
174+ device_allocator_.deallocate (device_data_,size_);
175+ device_data_ = nullptr ;
176+ device_allocated_ = false ;
203177 }
204178 }
205179
@@ -259,36 +233,22 @@ class DataStore : public ArrayDataStore {
259233 throw_Exception (ss.str (), loc);
260234 }
261235
// Removed by this change (every line carries the `-` marker).
// For n > 0 the removed helper added the byte size to MemoryHighWatermark::instance(),
// then called posix_memalign with an alignment of 64 * sizeof(Value) and raised
// throw_AllocationFailed on a non-zero return. For n == 0 it set ptr to nullptr.
262- void alloc_aligned (Value*& ptr, size_t n) {
263- if (n > 0 ) {
264- const size_t alignment = 64 * sizeof (Value);
265- size_t bytes = sizeof (Value) * n;
266- MemoryHighWatermark::instance () += bytes;
267-
268- int err = posix_memalign ((void **)&ptr, alignment, bytes);
269- if (err) {
270- throw_AllocationFailed (bytes, Here ());
271- }
272- }
273- else {
274- ptr = nullptr ;
275- }
276- }
277-
// Removed by this change (every line carries the `-` marker).
// When ptr was non-null the removed helper called free(ptr), reset ptr to nullptr and
// subtracted footprint() from MemoryHighWatermark::instance(); a null ptr was a no-op.
278- void free_aligned (Value*& ptr) {
279- if (ptr) {
280- free (ptr);
281- ptr = nullptr ;
282- MemoryHighWatermark::instance () -= footprint ();
283- }
284- }
285-
// Allocates host storage for size_ elements through host_allocator_, or sets
// host_data_ to nullptr when size_ is 0.
// Change in this diff: the alloc_aligned() call is replaced by inline code that adds
// footprint() to MemoryHighWatermark::instance() and then calls host_allocator_.allocate.
// NOTE(review): the watermark is incremented before allocate(); if allocate() can throw,
// the counter would stay incremented with nothing allocated -- confirm, and consider
// swapping the two statements.
// NOTE(review): the removed alloc_aligned() requested 64 * sizeof(Value) alignment from
// posix_memalign; whether the pluto host resource gives the same alignment is not
// visible here -- confirm.
286236 void allocateHost () {
287- alloc_aligned (host_data_, size_);
237+ if (size_ > 0 ) {
238+ MemoryHighWatermark::instance () += footprint ();
239+ host_data_ = host_allocator_.allocate (size_);
240+ }
241+ else {
242+ host_data_ = nullptr ;
243+ }
288244 }
289245
// Releases host storage when host_data_ is non-null: the buffer is returned with
// host_allocator_.deallocate(host_data_, size_), the pointer is reset to nullptr, and
// footprint() is subtracted from MemoryHighWatermark::instance(). A null host_data_
// makes this a no-op.
// Change in this diff: the free_aligned() call is replaced by the inline code.
290246 void deallocateHost () {
291- free_aligned (host_data_);
247+ if (host_data_) {
248+ host_allocator_.deallocate (host_data_, size_);
249+ host_data_ = nullptr ;
250+ MemoryHighWatermark::instance () -= footprint ();
251+ }
292252 }
293253
// Size of the stored data in bytes: sizeof(Value) multiplied by the element count size_.
294254 size_t footprint () const { return sizeof (Value) * size_; }
@@ -302,6 +262,8 @@ class DataStore : public ArrayDataStore {
302262 mutable bool device_allocated_{false };
303263 mutable bool acc_mapped_{false };
304264
// Allocators added by this change. The DataStore constructor initialises them from
// pluto::host_resource() and pluto::device_resource() respectively.
// device_allocator_ is mutable because it is used from the const member functions
// allocateDevice() and deallocateDevice().
265+ pluto::allocator<Value> host_allocator_;
266+ mutable pluto::allocator<Value> device_allocator_;
305267};
306268
307269// ------------------------------------------------------------------------------
@@ -311,22 +273,23 @@ class WrappedDataStore : public ArrayDataStore {
311273public:
312274
// Sets the initial device-side state; called from both constructors.
// When ATLAS_HAVE_GPU is set and pluto::devices() evaluates true, the device copy is
// marked as not updated. Otherwise device_data_ is made to point at host_data_.
// Change in this diff: devices() is replaced by pluto::devices().
313275 void init_device () {
314- if (ATLAS_HAVE_GPU && devices ()) {
276+ if (ATLAS_HAVE_GPU && pluto:: devices ()) {
315277 device_updated_ = false ;
316278 }
317279 else {
318280 device_data_ = host_data_;
319281 }
320282 }
321283
// Constructs a store around a caller-supplied host pointer and a size; no host memory
// is allocated here. Stores both arguments, initialises device_allocator_ from
// pluto::device_resource() (added by this change), then calls init_device().
322- WrappedDataStore (Value* host_data, size_t size): host_data_(host_data), size_(size) {
284+ WrappedDataStore (Value* host_data, size_t size): host_data_(host_data), size_(size),
285+ device_allocator_{pluto::device_resource ()} {
323286 init_device ();
324287 }
325288
326289 WrappedDataStore (Value* host_data, const ArraySpec& spec):
327290 host_data_ (host_data),
328- size_ (spec.size())
329- {
291+ size_ (spec.size()),
292+ device_allocator_{ pluto::device_resource ()} {
330293 init_device ();
331294 contiguous_ = spec.contiguous ();
332295 if (! contiguous_) {
@@ -363,25 +326,18 @@ class WrappedDataStore : public ArrayDataStore {
363326 }
364327
365328 void updateDevice () const override {
366- if (ATLAS_HAVE_GPU && devices ()) {
329+ if (ATLAS_HAVE_GPU && pluto:: devices ()) {
367330 if (not device_allocated_) {
368331 allocateDevice ();
369332 }
370333 if (contiguous_) {
371- hicError_t err = hicMemcpy (device_data_, host_data_, size_*sizeof (Value), hicMemcpyHostToDevice);
372- if (err != hicSuccess) {
373- throw_AssertionFailed (" Failed to updateDevice: " +std::string (hicGetErrorString (err)), Here ());
374- }
334+ pluto::copy_host_to_device (device_data_, host_data_, size_);
375335 }
376336 else {
377- hicError_t err = hicMemcpy2D (
378- device_data_, memcpy_h2d_pitch_ * sizeof (Value),
379- host_data_, memcpy_d2h_pitch_ * sizeof (Value),
380- memcpy_width_ * sizeof (Value), memcpy_height_,
381- hicMemcpyHostToDevice);
382- if (err != hicSuccess) {
383- throw_AssertionFailed (" Failed to updateDevice: " +std::string (hicGetErrorString (err)), Here ());
384- }
337+ pluto::copy_host_to_device_2D (
338+ device_data_, memcpy_h2d_pitch_,
339+ host_data_, memcpy_d2h_pitch_,
340+ memcpy_width_, memcpy_height_);
385341 }
386342 device_updated_ = true ;
387343 }
@@ -391,20 +347,13 @@ class WrappedDataStore : public ArrayDataStore {
391347 if constexpr (ATLAS_HAVE_GPU) {
392348 if (device_allocated_) {
393349 if (contiguous_) {
394- hicError_t err = hicMemcpy (host_data_, device_data_, size_*sizeof (Value), hicMemcpyDeviceToHost);
395- if (err != hicSuccess) {
396- throw_AssertionFailed (" Failed to updateHost: " +std::string (hicGetErrorString (err)), Here ());
397- }
350+ pluto::copy_device_to_host (host_data_, device_data_, size_);
398351 }
399352 else {
400- hicError_t err = hicMemcpy2D (
401- host_data_, memcpy_d2h_pitch_ * sizeof (Value),
402- device_data_, memcpy_h2d_pitch_ * sizeof (Value),
403- memcpy_width_ * sizeof (Value), memcpy_height_,
404- hicMemcpyDeviceToHost);
405- if (err != hicSuccess) {
406- throw_AssertionFailed (" Failed to updateHost: " +std::string (hicGetErrorString (err)), Here ());
407- }
353+ pluto::copy_device_to_host_2D (
354+ host_data_, memcpy_d2h_pitch_ ,
355+ device_data_, memcpy_h2d_pitch_,
356+ memcpy_width_, memcpy_height_);
408357 }
409358 host_updated_ = true ;
410359 }
@@ -435,15 +384,12 @@ class WrappedDataStore : public ArrayDataStore {
// Returns the device_allocated_ flag, which allocateDevice() sets and deallocateDevice() clears.
435384 bool deviceAllocated () const override { return device_allocated_; }
436385
437386 void allocateDevice () const override {
438- if (ATLAS_HAVE_GPU && devices ()) {
387+ if (ATLAS_HAVE_GPU && pluto:: devices ()) {
439388 if (device_allocated_) {
440389 return ;
441390 }
442391 if (size_) {
443- hicError_t err = hicMalloc ((void **)&device_data_, sizeof (Value)*size_);
444- if (err != hicSuccess) {
445- throw_AssertionFailed (" Failed to allocate GPU memory: " + std::string (hicGetErrorString (err)), Here ());
446- }
392+ device_data_ = device_allocator_.allocate (size_);
447393 device_allocated_ = true ;
448394 if (contiguous_) {
449395 accMap ();
@@ -453,18 +399,13 @@ class WrappedDataStore : public ArrayDataStore {
453399 }
454400
// Releases device storage when device_allocated_ is true. accUnmap() is called first,
// but only when contiguous_ is true (allocateDevice() likewise calls accMap() only in
// the contiguous case). The buffer is then returned with
// device_allocator_.deallocate(device_data_, size_) and device_data_ /
// device_allocated_ are reset.
// Change in this diff: the `if constexpr (ATLAS_HAVE_GPU)` wrapper and the hicFree
// error check/throw_AssertionFailed are removed.
455401 void deallocateDevice () const override {
456- if constexpr (ATLAS_HAVE_GPU) {
457- if (device_allocated_) {
458- if (contiguous_) {
459- accUnmap ();
460- }
461- hicError_t err = hicFree (device_data_);
462- if (err != hicSuccess) {
463- throw_AssertionFailed (" Failed to deallocate GPU memory: " + std::string (hicGetErrorString (err)), Here ());
464- }
465- device_data_ = nullptr ;
466- device_allocated_ = false ;
402+ if (device_allocated_) {
403+ if (contiguous_) {
404+ accUnmap ();
467405 }
406+ device_allocator_.deallocate (device_data_, size_);
407+ device_data_ = nullptr ;
408+ device_allocated_ = false ;
468409 }
469410 }
470411
@@ -505,7 +446,6 @@ class WrappedDataStore : public ArrayDataStore {
505446 }
506447
507448 void accUnmap () const override {
508- #if ATLAS_HAVE_ACC
509449 if (acc_mapped_) {
510450 ATLAS_ASSERT (atlas::acc::is_present (host_data_, size_ * sizeof (Value)));
511451 if constexpr (ATLAS_ACC_DEBUG) {
@@ -514,7 +454,6 @@ class WrappedDataStore : public ArrayDataStore {
514454 atlas::acc::unmap (host_data_);
515455 acc_mapped_ = false ;
516456 }
517- #endif
518457 }
519458
520459private:
@@ -532,6 +471,8 @@ class WrappedDataStore : public ArrayDataStore {
532471 mutable bool device_updated_{true };
533472 mutable bool device_allocated_{false };
534473 mutable bool acc_mapped_{false };
474+
// Device allocator added by this change; both constructors initialise it from
// pluto::device_resource(). Declared mutable because allocateDevice() and
// deallocateDevice(), which use it, are const member functions.
475+ mutable pluto::allocator<Value> device_allocator_;
535476};
536477
537478} // namespace native
0 commit comments