1+ /* ************************************************************************
2+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3+ * License for AMD contributions = MIT. See LICENSE for more information
4+ *************************************************************************/
5+
6+ #include < hip/hip_runtime.h>
7+ #include < rocshmem.hpp>
8+
9+ #include " ../util/logging_hip.h"
10+ #include " rocshmem_waitkernel.hpp"
11+
12+ using namespace rocshmem ;
13+
14+ __global__ void wait_until_on_stream_and_reset (uint64_t *wait_flag,
15+ uint64_t wait_value,
16+ uint64_t signal_reset) {
17+ rocshmem_ulonglong_wait_until ((unsigned long long *)wait_flag,
18+ ROCSHMEM_CMP_EQ,
19+ (unsigned long long )wait_value);
20+ }
21+
22+ __global__ void rocshmem_putmem_signal_kernel (void * dst_ptr, const void * src_ptr,
23+ size_t nelement, uint64_t * sig_addr,
24+ uint64_t sigval, int peer) {
25+ if (threadIdx.x == 0 && blockIdx.x == 0 ) {
26+ rocshmem_putmem (dst_ptr, src_ptr, nelement, peer);
27+ rocshmem_fence ();
28+ rocshmem_ulonglong_p ((unsigned long long *)sig_addr,
29+ (unsigned long long )sigval,
30+ peer);
31+ }
32+ }
33+
34+ void te_rocshmem_putmem_signal (void * dst_ptr, const void * src_ptr, size_t nelement,
35+ uint64_t * sig_addr, uint64_t sigval, int peer,
36+ hipStream_t cur_stream) {
37+ hipLaunchKernelGGL (rocshmem_putmem_signal_kernel,
38+ dim3 (1 ), dim3 (1 ), 0 , cur_stream,
39+ dst_ptr, src_ptr, nelement, sig_addr,
40+ sigval, peer);
41+ }
42+
43+ void te_rocshmem_wait_on_stream (uint64_t * sig_addr,
44+ WaitKind wait_kind,
45+ hipStream_t cur_stream) {
46+ uint64_t wait_value = 1 ;
47+ uint64_t signal_reset = 0 ;
48+
49+ NVTE_CHECK (wait_kind >= WaitKind::KERNEL_WAIT &&
50+ wait_kind <= WaitKind::STREAM_WAIT,
51+ " Invalid wait kind" );
52+
53+ switch (wait_kind) {
54+ // ### wait_until_on_stream not yet implemented for rocshmem ###
55+ // ### KernelWait is robust but slightly slower due to launch ###
56+ case WaitKind::ROCSHMEM_WAIT:
57+ // rocshmem__ulonglong_wait_until_on_stream(sig_addr,
58+ // ROCSHMEM_CMP_EQ,
59+ // wait_value,
60+ // cur_stream);
61+ // hipStreamWriteValue64(cur_stream,
62+ // reinterpret_cast<hipDeviceptr_t>(sig_addr),
63+ // signal_reset, 0);
64+ // break;
65+ case WaitKind::KERNEL_WAIT:
66+ hipLaunchKernelGGL (wait_until_on_stream_and_reset,
67+ dim3 (1 ), dim3 (1 ), 0 , cur_stream,
68+ sig_addr, wait_value, signal_reset);
69+ hipStreamWriteValue64 (cur_stream,
70+ reinterpret_cast <hipDeviceptr_t>(sig_addr),
71+ signal_reset, 0 );
72+ break ;
73+ case WaitKind::STREAM_WAIT:
74+ hipStreamWaitValue64 (cur_stream,
75+ reinterpret_cast <hipDeviceptr_t>(sig_addr),
76+ wait_value, hipStreamWaitValueGte);
77+ hipStreamWriteValue64 (cur_stream,
78+ reinterpret_cast <hipDeviceptr_t>(sig_addr),
79+ signal_reset, 0 );
80+ break ;
81+ }
82+ }
83+
84+ int te_rocshmem_init_thread (int required, int * provided) {
85+ if (required == 0 && provided == nullptr ) {
86+ rocshmem_init ();
87+ return 0 ;
88+ } else {
89+ return rocshmem_init_thread (required, provided);
90+ }
91+ }
92+
93+ void te_rocshmem_finalize () {
94+ rocshmem_finalize ();
95+ }
96+
97+ int te_rocshmem_my_pe () {
98+ return rocshmem_my_pe ();
99+ }
100+
101+ int te_rocshmem_n_pes () {
102+ return rocshmem_n_pes ();
103+ }
104+
105+ void * te_rocshmem_malloc (size_t size) {
106+ return rocshmem_malloc (size);
107+ }
108+
109+ void te_rocshmem_free (void * ptr) {
110+ rocshmem_free (ptr);
111+ }
112+
// Declaration only: no definition of te_rocshmem_wait_until appears in the
// visible part of this file.
// NOTE(review): presumably enqueues a wait on `stream` until *signal_addr
// equals expected_value -- confirm against the definition before relying on it.
113+ void te_rocshmem_wait_until (uint64_t * signal_addr, uint64_t expected_value,
114+ hipStream_t stream);
0 commit comments