Skip to content

Commit 685c791

Browse files
authored
Merge pull request #3 from SCOREC/multi-slice-cab-factory
CabanaSliceFactory now works with multiple types.
2 parents 03081e8 + 790cedd commit 685c791

File tree

2 files changed

+38
-24
lines changed

2 files changed

+38
-24
lines changed

src/SliceWrapper.hpp

Lines changed: 21 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,8 @@ struct SliceWrapper {
1111
SliceWrapper(SliceType st) : st_(st) {}
1212

1313
KOKKOS_INLINE_FUNCTION
14-
T& access(const int s, const int a, int i) const {
15-
return st_.access(s,a,i);
14+
T& access(const int s, const int a) const {
15+
return st_.access(s,a);
1616
}
1717
int arraySize(int s) {
1818
return st_.arraySize(s);
@@ -24,23 +24,34 @@ struct SliceWrapper {
2424

2525
using namespace Cabana;
2626

27-
template <class ExecutionSpace, class MemorySpace, class T, int width, int vecLen>
27+
template <class ExecutionSpace, class MemorySpace, class... Ts>
2828
class CabSliceFactory {
29+
static constexpr int vecLen = Impl::PerformanceTraits<ExecutionSpace>::vector_length/8;
30+
using TypeTuple = std::tuple<Ts...>;
2931
using DeviceType = Kokkos::Device<ExecutionSpace, MemorySpace>;
30-
using DataTypes = Cabana::MemberTypes<T[width]>;
32+
using DataTypes = Cabana::MemberTypes<Ts...>;
33+
using soa_t = SoA<DataTypes, vecLen>;
34+
35+
template <class T, int stride>
3136
using member_slice_t =
32-
Cabana::Slice<T[width], DeviceType,
37+
Cabana::Slice<T, DeviceType,
3338
Cabana::DefaultAccessMemory,
34-
vecLen, width*vecLen>;
35-
using wrapper_slice_t = SliceWrapper<member_slice_t, T>;
39+
vecLen, stride>;
40+
41+
template <class T, int stride>
42+
using wrapper_slice_t = SliceWrapper<member_slice_t<T, stride>, T>;
3643

3744
Cabana::AoSoA<DataTypes, DeviceType, vecLen> aosoa;
3845

3946
public:
40-
wrapper_slice_t makeSliceCab() {
41-
auto slice0 = Cabana::slice<0>(aosoa);
42-
return wrapper_slice_t(std::move(slice0));
47+
template <std::size_t index>
48+
auto makeSliceCab() {
49+
using type = std::tuple_element_t<index, TypeTuple>;
50+
const int stride = (vecLen * sizeof(soa_t)) / (4 * sizeof(type));
51+
auto slice = Cabana::slice<index>(aosoa);
52+
return wrapper_slice_t< type, stride >(std::move(slice));
4353
}
54+
4455
CabSliceFactory(int n) : aosoa("sliceAoSoA", n) {}
4556
};
4657

test/SliceWrapper.cpp

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,37 @@
33
int main(int argc, char* argv[]) {
44
// AoSoA parameters
55
const int vecLen = 4;
6-
const int width = 1;
76
int num_tuples = 10;
87

98
Kokkos::ScopeGuard scope_guard(argc, argv);
109

11-
using member_type = double;
12-
using DataTypes = Cabana::MemberTypes<member_type[width]>;
13-
using ExecutionSpace = Kokkos::Cuda;
14-
using MemorySpace = Kokkos::CudaSpace;
15-
10+
using ExecutionSpace = Kokkos::DefaultExecutionSpace;
11+
using MemorySpace = ExecutionSpace::memory_space;
1612

1713
// Slice Wrapper Factory
1814
CabSliceFactory<ExecutionSpace, MemorySpace,
19-
member_type, width, vecLen> cabSliceFactory(num_tuples);
15+
double, int, float, char> cabSliceFactory(num_tuples);
2016

21-
auto slice_wrapper = cabSliceFactory.makeSliceCab();
17+
auto slice_wrapper0 = cabSliceFactory.makeSliceCab<0>();
18+
auto slice_wrapper1 = cabSliceFactory.makeSliceCab<1>();
19+
auto slice_wrapper2 = cabSliceFactory.makeSliceCab<2>();
20+
auto slice_wrapper3 = cabSliceFactory.makeSliceCab<3>();
2221

2322
// simd_parallel_for setup
2423
Cabana::SimdPolicy<vecLen, ExecutionSpace> simd_policy(0, num_tuples);
2524

2625
// kernel that reads and writes
2726
auto vector_kernel = KOKKOS_LAMBDA(const int s, const int a) {
28-
for (int i = 0; i < width; i++) {
29-
printf("s: %d, a: %d, i: %d\n", s,a,i);
30-
double x = 42/(s+a+1.3);
31-
slice_wrapper.access(s,a,i) = x;
32-
printf("value: %lf\n", slice_wrapper.access(s,a,i));
33-
}
27+
printf("s: %d, a: %d\n", s,a);
28+
double x = 42/(s+a+1.3);
29+
slice_wrapper0.access(s,a) = x;
30+
slice_wrapper1.access(s,a) = s+a;
31+
slice_wrapper2.access(s,a) = float(x);
32+
slice_wrapper3.access(s,a) = 'a'+s+a;
33+
printf("SW0 value: %lf\n", slice_wrapper0.access(s,a));
34+
printf("SW1 value: %d\n", slice_wrapper1.access(s,a));
35+
printf("SW2 value: %f\n", slice_wrapper2.access(s,a));
36+
printf("SW3 value: %c\n", slice_wrapper3.access(s,a));
3437
};
3538

3639
Cabana::simd_parallel_for(simd_policy, vector_kernel, "parallel_for_cabSliceFactory");

0 commit comments

Comments
 (0)