@@ -304,26 +304,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
304304 " Clean up and finalize the NVSHMEM communication backend and free associated resources" ,
305305 py::call_guard<py::gil_scoped_release>());
306306#else
307- // rocshmem functions
308- m.def (" init_rocshmem_backend" , &transformer_engine::pytorch::init_rocshmem_backend,
309- " Initialize ROCSHMEM backend with Pytorch distributed process groups" ,
310- py::call_guard<py::gil_scoped_release>());
311- m.def (" create_rocshmem_tensor" , &transformer_engine::pytorch::create_rocshmem_tensor,
312- " Create a tensor in ROCSHMEM shared memory" , py::call_guard<py::gil_scoped_release>());
313- m.def (" rocshmem_send_on_current_stream" ,
314- &transformer_engine::pytorch::rocshmem_send_on_current_stream,
315- " Asynchronously send tensor data to a remote PE using ROCSHMEM on the current HIP stream" ,
316- py::call_guard<py::gil_scoped_release>());
317- m.def (" rocshmem_wait_on_current_stream" ,
318- &transformer_engine::pytorch::rocshmem_wait_on_current_stream,
319- " Wait for a signal value to be updated by a remote PE using ROCSHMEM on the current HIP "
320- " stream" ,
321- py::call_guard<py::gil_scoped_release>());
322- m.def (" rocshmem_finalize" , &transformer_engine::pytorch::rocshmem_finalize,
323- " Clean up and finalize the ROCSHMEM communication backend and free associated resources" ,
324- py::call_guard<py::gil_scoped_release>());
325-
326- // nvshmem wrappers
307+ // nvshmem/rocshmem wrappers
327308 m.def (" init_nvshmem_backend" , &transformer_engine::pytorch::init_rocshmem_backend,
328309 " Initialize ROCSHMEM backend with Pytorch distributed process groups" ,
329310 py::call_guard<py::gil_scoped_release>());
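For context, here is a minimal, self-contained sketch of the aliasing pattern this hunk converges on: a single set of Python-facing names, with the preprocessor picking the NVSHMEM or ROCSHMEM implementation at build time so Python callers need no backend-specific code. The module name `example_ext`, the `USE_ROCM` guard, and the `backend` stubs are illustrative assumptions; the real extension registers under `TORCH_EXTENSION_NAME` and binds the `transformer_engine::pytorch` implementations.

```cpp
#include <pybind11/pybind11.h>

namespace py = pybind11;

namespace backend {
// Hypothetical stand-ins for the real init functions in
// transformer_engine::pytorch; the actual signatures take
// process-group arguments and live in the extension sources.
void init_nvshmem_backend() {}
void init_rocshmem_backend() {}
}  // namespace backend

PYBIND11_MODULE(example_ext, m) {
#ifndef USE_ROCM  // assumed guard name; the real file uses its own macro
  // CUDA build: bind the NVSHMEM implementation under the stable name.
  m.def("init_nvshmem_backend", &backend::init_nvshmem_backend,
        "Initialize the NVSHMEM backend",
        py::call_guard<py::gil_scoped_release>());
#else
  // ROCm build: keep the same Python-facing name but route it to the
  // ROCSHMEM implementation, releasing the GIL for the blocking call.
  m.def("init_nvshmem_backend", &backend::init_rocshmem_backend,
        "Initialize the ROCSHMEM backend",
        py::call_guard<py::gil_scoped_release>());
#endif
}
```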