Skip to content

Commit 4e1810d

Browse files
committed
add rocshmem wait warning
1 parent 7aae806 commit 4e1810d

File tree

2 files changed: 3 additions (+3) and 20 deletions (-20).

transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,8 @@ void te_rocshmem_wait_on_stream(uint64_t* sig_addr,
5454
// ### wait_until_on_stream not yet implemented for rocshmem ###
5555
// ### KernelWait is robust but slightly slower due to launch ###
5656
case WaitKind::ROCSHMEM_WAIT:
57+
printf("WARNING: rocshmem wait is not implemented yet, defaulting to "
58+
"kernel wait.\n");
5759
// rocshmem__ulonglong_wait_until_on_stream(sig_addr,
5860
// ROCSHMEM_CMP_EQ,
5961
// wait_value,

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 1 addition & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -304,26 +304,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
304304
"Clean up and finalize the NVSHMEM communication backend and free associated resources",
305305
py::call_guard<py::gil_scoped_release>());
306306
#else
307-
// rocshmem functions
308-
m.def("init_rocshmem_backend", &transformer_engine::pytorch::init_rocshmem_backend,
309-
"Initialize ROCSHMEM backend with Pytorch distributed process groups",
310-
py::call_guard<py::gil_scoped_release>());
311-
m.def("create_rocshmem_tensor", &transformer_engine::pytorch::create_rocshmem_tensor,
312-
"Create a tensor in ROCSHMEM shared memory", py::call_guard<py::gil_scoped_release>());
313-
m.def("rocshmem_send_on_current_stream",
314-
&transformer_engine::pytorch::rocshmem_send_on_current_stream,
315-
"Asynchronously send tensor data to a remote PE using ROCSHMEM on the current HIP stream",
316-
py::call_guard<py::gil_scoped_release>());
317-
m.def("rocshmem_wait_on_current_stream",
318-
&transformer_engine::pytorch::rocshmem_wait_on_current_stream,
319-
"Wait for a signal value to be updated by a remote PE using ROCSHMEM on the current HIP "
320-
"stream",
321-
py::call_guard<py::gil_scoped_release>());
322-
m.def("rocshmem_finalize", &transformer_engine::pytorch::rocshmem_finalize,
323-
"Clean up and finalize the ROCSHMEM communication backend and free associated resources",
324-
py::call_guard<py::gil_scoped_release>());
325-
326-
// nvshmem wrappers
307+
// nvshmem/rocshmem wrappers
327308
m.def("init_nvshmem_backend", &transformer_engine::pytorch::init_rocshmem_backend,
328309
"Initialize ROCSHMEM backend with Pytorch distributed process groups",
329310
py::call_guard<py::gil_scoped_release>());

0 commit comments

Comments
 (0)