Commit 617fa07
[Multidevice] Tma bulk copy p2p runtime examples (#6011)
[Multidevice] Tma bulk copy p2p runtime examples (#6011)

## What

Add a Hopper TMA (`cp.async.bulk`) copy kernel in `runtime/tma_copy.cu` and validate it across three memory source/destination types:

- local GMEM
- peer symmetric memory: TMA can write from local shared memory to remote global memory
- NVLS multicast pointers: by using the multicast pointer as the destination of the TMA request, data can be broadcast to the whole NVL domain in one shot at line rate. Note, however, that this is not officially supported according to the CUDA documentation.

These behaviors are demonstrated through three unit tests in `tests/cpp/test_multidevice_tma.cpp`. The tests reuse the `SymmetricTensor` abstraction for VMM allocation, IPC handle exchange, and multicast setup, keeping the test bodies focused on the TMA transfer itself.

## Why

The CUDA backend for multi-device communication (`csrc/multidevice/cuda_p2p.cpp`) currently uses SM-based copies (regular threads load/store, or `multimem.st`) and copy-engine copies (`cudaMemcpyAsync` / `cudaMemcpyBatchAsync`). TMA offers a third transport option that is GPU-initiated, lightweight (single-thread issue), fully asynchronous, and frees SM resources for overlapping compute. This transport is leveraged by DeepEP for intra-node MoE dispatch.

This PR validates that TMA works correctly on the memory types used by nvFuser's multi-device infrastructure. It lays the groundwork for a follow-up PR that integrates TMA as a transport option for P2P and multicast communications alongside the existing SM-based copies and copy-engine transports.

## How

- The kernel is implemented in `runtime/tma_copy.cu`. It is a single-warp kernel where thread 0 performs a two-phase TMA transfer through shared memory (`GMEM(src) --[TMA load]--> SMEM --[TMA store]--> GMEM(dst)`), using an `mbarrier` for async completion tracking. TMA is a GMEM-SMEM engine; there is no GMEM-to-GMEM variant, so shared-memory staging is inherent to the hardware.
- The kernel is compiled at runtime via NVRTC (the same pattern as the existing `alltoallv.cu` and `multicast.cu` kernels in `cuda_p2p.cpp`, and other kernels in `runtime/`) and stringified at build time through the existing `NVFUSER_RUNTIME_FILES` pipeline.
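The load phase of the kernel relies on the mbarrier transaction-count protocol: `mbarrier.arrive.expect_tx` raises the expected byte count, the TMA unit lowers it as data lands, and the barrier phase flips once both the arrival count and the transaction count reach zero. The following toy model (illustration only; the type and method names are hypothetical, and real hardware semantics are richer) sketches that bookkeeping:

```cpp
#include <cassert>
#include <cstdint>

// Toy model of an mbarrier's phase/transaction bookkeeping. Not the
// hardware implementation -- just the accounting the kernel relies on.
struct ToyMbarrier {
  int pending_arrivals = 1; // arrival count set by mbarrier.init
  int64_t expected_tx = 0;  // outstanding transaction bytes
  int phase = 0;

  // mbarrier.arrive.expect_tx: one arrival + announce expected bytes.
  void arrive_expect_tx(int64_t bytes) {
    expected_tx += bytes;
    --pending_arrivals;
    maybe_flip();
  }

  // complete_tx: issued by the TMA unit as bytes arrive in SMEM.
  void complete_tx(int64_t bytes) {
    expected_tx -= bytes;
    maybe_flip();
  }

  // mbarrier.try_wait.parity: has the barrier moved past `parity`?
  bool try_wait_parity(int parity) const {
    return phase != parity;
  }

 private:
  void maybe_flip() {
    if (pending_arrivals == 0 && expected_tx == 0) {
      ++phase;
      pending_arrivals = 1; // re-arm for the next phase
    }
  }
};
```

In the kernel, thread 0 spins on `try_wait_parity(0)` after issuing the load; the announced bytes must exactly match what the TMA transfer delivers, or the phase never flips and the wait loop hangs.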
1 parent fe948b6 commit 617fa07

File tree

4 files changed: +378 −1 lines changed

CMakeLists.txt

Lines changed: 6 additions & 1 deletion

```diff
@@ -998,6 +998,7 @@ if(BUILD_TEST)
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir_overlap.cpp
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_ipc.cpp
+    ${NVFUSER_ROOT}/tests/cpp/test_multidevice_tma.cpp
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication.cpp
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication_cuda.cpp
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_matmul.cpp
@@ -1008,6 +1009,9 @@ if(BUILD_TEST)
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_transformer.cpp
   )
   add_test_without_main(test_multidevice "${MULTIDEVICE_TEST_SRCS}" "")
+  target_include_directories(test_multidevice PRIVATE
+    "${CMAKE_BINARY_DIR}/include")
+  add_dependencies(test_multidevice nvfuser_rt_tma_copy)
   list(APPEND TEST_BINARIES test_multidevice)

   set(MULTIDEVICE_TUTORIAL_SRCS)
@@ -1239,7 +1243,8 @@ list(APPEND NVFUSER_RUNTIME_FILES
   ${NVFUSER_ROOT}/runtime/mbarrier.cu
   ${NVFUSER_ROOT}/runtime/memory.cu
   ${NVFUSER_ROOT}/runtime/multicast.cu
-  ${NVFUSER_SRCS_DIR}/multidevice/alltoallv.cu
+  ${NVFUSER_ROOT}/runtime/alltoallv.cu
+  ${NVFUSER_ROOT}/runtime/tma_copy.cu
   ${NVFUSER_ROOT}/runtime/random_numbers.cu
   ${NVFUSER_ROOT}/runtime/tensor_memory.cu
   ${NVFUSER_ROOT}/runtime/tensor.cu
```

runtime/tma_copy.cu

Lines changed: 101 additions & 0 deletions (new file)

```cuda
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
//
// TMA 1D bulk copy kernel (SM90+ / Hopper).
//
// This file implements a TMA-based data copy kernel. The build system
// stringifies it into nvfuser_resources/tma_copy.h (a const char*),
// which is compiled at runtime via NVRTC. The file is never compiled
// statically by nvcc.
//
// Currently used by tests (test_multidevice_tma.cpp). In a future PR
// this kernel will be integrated as a P2P and multicast transport
// alongside the existing SM-based and copy-engine transports in
// csrc/multidevice/cuda_p2p.{h,cpp}.
//
// TMA (cp.async.bulk) is a GMEM<->SMEM transfer engine — there is no
// GMEM-to-GMEM variant. Shared memory staging is inherent to the
// hardware, so the kernel performs a two-phase copy:
//
//   GMEM(src) --[TMA load]--> SMEM --[TMA store]--> GMEM(dst)
//
// A single elected thread (thread 0) drives both phases:
//   1. mbarrier.init (arrival count = 1)
//   2. mbarrier.arrive.expect_tx (announce expected bytes)
//   3. cp.async.bulk.shared::cluster.global (TMA load, async)
//   4. mbarrier.try_wait.parity (block until load completes)
//   5. cp.async.bulk.global.shared::cta (TMA store)
//   6. cp.async.bulk.commit_group + wait_group.read 0
//
// Dynamic shared memory layout (128-byte aligned):
//   [0, num_bytes)           : staging buffer
//   [num_bytes, num_bytes+8) : mbarrier (uint64_t)

extern "C" __global__ void __launch_bounds__(32, 1) tma_copy_1d(
    void* __restrict__ dst,
    const void* __restrict__ src,
    int num_bytes) {
  extern __shared__ __align__(128) unsigned char smem[];

  unsigned long long* mbar =
      reinterpret_cast<unsigned long long*>(smem + num_bytes);
  unsigned int smem_addr =
      static_cast<unsigned int>(__cvta_generic_to_shared(smem));
  unsigned int mbar_addr =
      static_cast<unsigned int>(__cvta_generic_to_shared(mbar));

  if (threadIdx.x == 0) {
    asm volatile(
        "mbarrier.init.shared::cta.b64 [%0], %1;" ::"r"(mbar_addr), "r"(1));
    asm volatile("fence.mbarrier_init.release.cluster;" :::);
  }
  __syncwarp();

  if (threadIdx.x == 0) {
    // Announce expected transaction bytes on the mbarrier
    asm volatile(
        "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;" ::"r"(
            mbar_addr),
        "r"(num_bytes));

    // TMA Load: GMEM -> SMEM (async, completed via mbarrier)
    asm volatile(
        "cp.async.bulk.shared::cluster.global"
        ".mbarrier::complete_tx::bytes"
        " [%0], [%1], %2, [%3];\n" ::"r"(smem_addr),
        "l"(src),
        "r"(num_bytes),
        "r"(mbar_addr)
        : "memory");

    // Block until the mbarrier phase flips (TMA load completed)
    asm volatile(
        "{\n"
        ".reg .pred P1;\n"
        "TMA_COPY_WAIT_LOAD:\n"
        "mbarrier.try_wait.parity.shared::cta.b64"
        " P1, [%0], %1;\n"
        "@P1 bra TMA_COPY_LOAD_DONE;\n"
        "bra TMA_COPY_WAIT_LOAD;\n"
        "TMA_COPY_LOAD_DONE:\n"
        "}" ::"r"(mbar_addr),
        "r"(0));

    // TMA Store: SMEM -> GMEM
    asm volatile(
        "cp.async.bulk.global.shared::cta.bulk_group"
        " [%0], [%1], %2;\n" ::"l"(dst),
        "r"(smem_addr),
        "r"(num_bytes)
        : "memory");
    asm volatile("cp.async.bulk.commit_group;");
    asm volatile("cp.async.bulk.wait_group.read 0;" ::: "memory");

    asm volatile("mbarrier.inval.shared::cta.b64 [%0];" ::"r"(mbar_addr));
  }
}
```
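The kernel's shared-memory layout places the 8-byte mbarrier immediately after the `num_bytes` staging buffer, so the host launch must size the dynamic shared memory accordingly. A minimal host-side sketch (the helper name is hypothetical, not from the PR; the 16-byte-multiple requirement comes from the PTX spec for `cp.async.bulk`):

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: dynamic shared memory needed by tma_copy_1d.
// Layout (matches the kernel's header comment):
//   [0, num_bytes)           : staging buffer
//   [num_bytes, num_bytes+8) : mbarrier (uint64_t)
// cp.async.bulk requires the transfer size to be a multiple of 16
// bytes, which also keeps the mbarrier at smem + num_bytes aligned.
constexpr std::size_t kMbarrierBytes = sizeof(unsigned long long); // 8

inline std::size_t tma_copy_smem_bytes(std::size_t num_bytes) {
  assert(num_bytes % 16 == 0 && "cp.async.bulk size must be a multiple of 16");
  return num_bytes + kMbarrierBytes;
}
```

For payloads above the default 48 KiB dynamic shared-memory limit, the launch would also need to raise the kernel's max dynamic shared-memory attribute (via the driver API, since the kernel is loaded through NVRTC) before launching with 1 block of 32 threads.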
