|
3 | 3 | int main(int argc, char* argv[]) { |
4 | 4 | // AoSoA parameters |
5 | 5 | const int vecLen = 4; |
6 | | - const int width = 1; |
7 | 6 | int num_tuples = 10; |
8 | 7 |
|
9 | 8 | Kokkos::ScopeGuard scope_guard(argc, argv); |
10 | 9 |
|
11 | | - using member_type = double; |
12 | | - using DataTypes = Cabana::MemberTypes<member_type[width]>; |
13 | | - using ExecutionSpace = Kokkos::Cuda; |
14 | | - using MemorySpace = Kokkos::CudaSpace; |
15 | | - |
| 10 | + using ExecutionSpace = Kokkos::DefaultExecutionSpace; |
| 11 | + using MemorySpace = ExecutionSpace::memory_space; |
16 | 12 |
|
17 | 13 | // Slice Wrapper Factory |
18 | 14 | CabSliceFactory<ExecutionSpace, MemorySpace, |
19 | | - member_type, width, vecLen> cabSliceFactory(num_tuples); |
| 15 | + double, int, float, char> cabSliceFactory(num_tuples); |
20 | 16 |
|
21 | | - auto slice_wrapper = cabSliceFactory.makeSliceCab(); |
| 17 | + auto slice_wrapper0 = cabSliceFactory.makeSliceCab<0>(); |
| 18 | + auto slice_wrapper1 = cabSliceFactory.makeSliceCab<1>(); |
| 19 | + auto slice_wrapper2 = cabSliceFactory.makeSliceCab<2>(); |
| 20 | + auto slice_wrapper3 = cabSliceFactory.makeSliceCab<3>(); |
22 | 21 |
|
23 | 22 | // simd_parallel_for setup |
24 | 23 | Cabana::SimdPolicy<vecLen, ExecutionSpace> simd_policy(0, num_tuples); |
25 | 24 |
|
26 | 25 | // kernel that reads and writes |
27 | 26 | auto vector_kernel = KOKKOS_LAMBDA(const int s, const int a) { |
28 | | - for (int i = 0; i < width; i++) { |
29 | | - printf("s: %d, a: %d, i: %d\n", s,a,i); |
30 | | - double x = 42/(s+a+1.3); |
31 | | - slice_wrapper.access(s,a,i) = x; |
32 | | - printf("value: %lf\n", slice_wrapper.access(s,a,i)); |
33 | | - } |
| 27 | + printf("s: %d, a: %d\n", s,a); |
| 28 | + double x = 42/(s+a+1.3); |
| 29 | + slice_wrapper0.access(s,a) = x; |
| 30 | + slice_wrapper1.access(s,a) = s+a; |
| 31 | + slice_wrapper2.access(s,a) = float(x); |
| 32 | + slice_wrapper3.access(s,a) = 'a'+s+a; |
| 33 | + printf("SW0 value: %lf\n", slice_wrapper0.access(s,a)); |
| 34 | + printf("SW1 value: %d\n", slice_wrapper1.access(s,a)); |
| 35 | + printf("SW2 value: %f\n", slice_wrapper2.access(s,a)); |
| 36 | + printf("SW3 value: %c\n", slice_wrapper3.access(s,a)); |
34 | 37 | }; |
35 | 38 |
|
36 | 39 | Cabana::simd_parallel_for(simd_policy, vector_kernel, "parallel_for_cabSliceFactory"); |
|
0 commit comments