5454
5555 */
5656
57- #include < iostream>
58-
59- #include < hip/hip_runtime_api.h>
60- #include < hip/hip_runtime.h>
6157#include < rocshmem/rocshmem.hpp>
6258
63- #define CHECK_HIP (condition ) { \
64- hipError_t error = condition; \
65- if (error != hipSuccess){ \
66- fprintf (stderr," HIP error: %d line: %d\n " , error, __LINE__); \
67- MPI_Abort (MPI_COMM_WORLD, error); \
68- } \
69- }
59+ #include " util.h"
7060
7161using namespace rocshmem ;
7262
@@ -95,20 +85,20 @@ __global__ void simple_put_signal_test(uint64_t *data, uint64_t *message, size_t
9585
9686int main (int argc, char **argv)
9787{
98- int rank = rocshmem_my_pe ();
99- int ndevices, my_device = 0 ;
100- CHECK_HIP (hipGetDeviceCount (&ndevices));
101- my_device = rank % ndevices;
102- CHECK_HIP (hipSetDevice (my_device));
10388 int nelem = MAX_ELEM;
10489
10590 if (argc > 1 ) {
10691 nelem = atoi (argv[1 ]);
10792 }
10893
94+ CHECK_HIP (hipSetDevice (get_launcher_local_rank ()));
95+
10996 rocshmem_init ();
97+
98+ int my_pe = rocshmem_my_pe ();
11099 int npes = rocshmem_n_pes ();
111- int dst_pe = (rank + 1 ) % npes;
100+
101+ int dst_pe = (my_pe + 1 ) % npes;
112102 uint64_t *message = (uint64_t *)rocshmem_malloc (nelem * sizeof (uint64_t ));
113103 uint64_t *data = (uint64_t *)rocshmem_malloc (nelem * sizeof (uint64_t ));
114104 uint64_t *sig_addr = (uint64_t *)rocshmem_malloc (sizeof (uint64_t ));
@@ -123,14 +113,14 @@ int main (int argc, char **argv)
123113 }
124114
125115 for (int i=0 ; i<nelem; i++) {
126- message[i] = rank ;
116+ message[i] = my_pe ;
127117 }
128118
129119 CHECK_HIP (hipMemset (data, 0 , (nelem * sizeof (uint64_t ))));
130120 CHECK_HIP (hipDeviceSynchronize ());
131121
132122 int threadsPerBlock=256 ;
133- simple_put_signal_test<<<dim3 (1 ), dim3 (threadsPerBlock), 0 , 0 >>>(data, message, nelem, sig_addr, rank , dst_pe);
123+ simple_put_signal_test<<<dim3 (1 ), dim3 (threadsPerBlock), 0 , 0 >>>(data, message, nelem, sig_addr, my_pe , dst_pe);
134124 rocshmem_barrier_all ();
135125 CHECK_HIP (hipDeviceSynchronize ());
136126
@@ -139,11 +129,11 @@ int main (int argc, char **argv)
139129 if (data[i] != 0 ) {
140130 pass = false ;
141131#if VERBOSE
142- printf (" [%d] Error in element %d expected 0 got %d\n " , rank , i, dst[i]);
132+ printf (" [%d] Error in element %d expected 0 got %d\n " , my_pe , i, dst[i]);
143133#endif
144134 }
145135 }
146- printf (" [%d] Test %s \t %s\n " , rank , argv[0 ], pass ? " [PASS]" : " [FAIL]" );
136+ printf (" [%d] Test %s \t %s\n " , my_pe , argv[0 ], pass ? " [PASS]" : " [FAIL]" );
147137
148138 rocshmem_free (data);
149139 rocshmem_free (message);
0 commit comments