Skip to content

Commit 96ce9ba

Browse files
committed
Move raw_pointer_cast to constructor of Array2D
1 parent cd5b439 commit 96ce9ba

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

genmetaballs/src/cuda/core/utils.cuh

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include <cuda/std/mdspan>
66
#include <cuda/std/span>
77
#include <cuda_runtime.h>
8+
#include <memory>
9+
#include <thrust/memory.h>
810

911
#define CUDA_CALLABLE __host__ __device__
1012

@@ -29,8 +31,9 @@ private:
2931

3032
public:
3133
// constructor
32-
CUDA_CALLABLE constexpr Array2D(T* data_ptr, uint32_t rows, uint32_t cols)
33-
: data_view_(data_ptr, rows, cols) {}
34+
template <typename Pointer>
35+
CUDA_CALLABLE constexpr Array2D(Pointer data_ptr, uint32_t rows, uint32_t cols)
36+
: data_view_(thrust::raw_pointer_cast(data_ptr), rows, cols) {}
3437

3538
// getting a 1D view of a specific row
3639
// this supports array2d[row][col] access pattern and range-based for loops
@@ -54,3 +57,10 @@ public:
5457
return data_view_.size();
5558
}
5659
}; // class Array2D
60+
61+
// Type deduction guide
62+
// if initialized with (Pointer, int, int), deduce T by looking at what raw_pointer_cast returns
63+
// so we can write Array2D(array_ptr, rows, cols) instead of Array2D<Type>(array_ptr, rows, cols)
64+
template <typename Pointer>
65+
Array2D(Pointer, uint32_t, uint32_t)
66+
-> Array2D<typename std::pointer_traits<Pointer>::element_type>;

tests/cpp_tests/test_utils.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ TYPED_TEST(Array2DTestFixture, CreateAndAccessArray2D) {
121121

122122
auto data = TypeParam(rows * cols);
123123
// create 2D view into the underlying data on host or device
124-
auto array2d = Array2D(thrust::raw_pointer_cast(data.data()), rows, cols);
124+
auto array2d = Array2D(data.data(), rows, cols);
125125

126126
if constexpr (std::is_same_v<TypeParam, std::vector<float>>) {
127127
for (auto i = 0; i < rows - 1; i++) {
@@ -164,7 +164,7 @@ TYPED_TEST(Array2DTestFixture, ViewModifiesUnderlyingData) {
164164
uint32_t rows = 3;
165165
uint32_t cols = 4;
166166
auto data = TypeParam(rows * cols, 0.0f);
167-
auto array2d = Array2D(thrust::raw_pointer_cast(data.data()), rows, cols);
167+
auto array2d = Array2D(data.data(), rows, cols);
168168

169169
// Modify through view
170170
array2d[1][2] = 42.5f;
@@ -184,8 +184,8 @@ TYPED_TEST(Array2DTestFixture, MultipleViewsOfSameData) {
184184
uint32_t rows = 2;
185185
uint32_t cols = 3;
186186
auto data = TypeParam(rows * cols, 0.0f);
187-
auto view1 = Array2D(thrust::raw_pointer_cast(data.data()), rows, cols);
188-
auto view2 = Array2D(thrust::raw_pointer_cast(data.data()), rows, cols);
187+
auto view1 = Array2D(data.data(), rows, cols);
188+
auto view2 = Array2D(data.data(), rows, cols);
189189

190190
// Modify through view1
191191
view1[0][0] = 100.0f;

0 commit comments

Comments
 (0)