Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,7 @@ if(BUILD_TEST)
${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir_overlap.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_ipc.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_tma.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication_cuda.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_matmul.cpp
Expand All @@ -1008,6 +1009,9 @@ if(BUILD_TEST)
${NVFUSER_ROOT}/tests/cpp/test_multidevice_transformer.cpp
)
add_test_without_main(test_multidevice "${MULTIDEVICE_TEST_SRCS}" "")
target_include_directories(test_multidevice PRIVATE
"${CMAKE_BINARY_DIR}/include")
add_dependencies(test_multidevice nvfuser_rt_tma_copy)
list(APPEND TEST_BINARIES test_multidevice)

set(MULTIDEVICE_TUTORIAL_SRCS)
Expand Down Expand Up @@ -1239,7 +1243,8 @@ list(APPEND NVFUSER_RUNTIME_FILES
${NVFUSER_ROOT}/runtime/mbarrier.cu
${NVFUSER_ROOT}/runtime/memory.cu
${NVFUSER_ROOT}/runtime/multicast.cu
${NVFUSER_SRCS_DIR}/multidevice/alltoallv.cu
${NVFUSER_ROOT}/runtime/alltoallv.cu
${NVFUSER_ROOT}/runtime/tma_copy.cu
${NVFUSER_ROOT}/runtime/random_numbers.cu
${NVFUSER_ROOT}/runtime/tensor_memory.cu
${NVFUSER_ROOT}/runtime/tensor.cu
Expand Down
File renamed without changes.
101 changes: 101 additions & 0 deletions runtime/tma_copy.cu
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@naoyam organization-wise, do you prefer to move this file (and alltoallv.cu) into the runtime/ directory, e.g. as runtime/tma_copy.cu?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, since that directory is the one where we hold all runtime code.

Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
//
// TMA 1D bulk copy kernel (SM90+ / Hopper).
//
// This file implements a TMA-based data copy kernel. The build system
// stringifies it into nvfuser_resources/tma_copy.h (a const char*),
// which is compiled at runtime via NVRTC. The file is never compiled
// statically by nvcc.
//
// Currently used by tests (test_multidevice_tma.cpp). In a future PR
// this kernel will be integrated as a P2P and multicast transport
// alongside the existing SM-based and copy-engine transports in
// csrc/multidevice/cuda_p2p.{h,cpp}.
//
// TMA (cp.async.bulk) is a GMEM<->SMEM transfer engine — there is no
// GMEM-to-GMEM variant. Shared memory staging is inherent to the
// hardware, so the kernel performs a two-phase copy:
//
// GMEM(src) --[TMA load]--> SMEM --[TMA store]--> GMEM(dst)
//
// A single elected thread (thread 0) drives both phases:
// 1. mbarrier.init (arrival count = 1)
// 2. mbarrier.arrive.expect_tx (announce expected bytes)
// 3. cp.async.bulk.shared::cluster.global (TMA load, async)
// 4. mbarrier.try_wait.parity (block until load completes)
// 5. cp.async.bulk.global.shared::cta (TMA store)
// 6. cp.async.bulk.commit_group + wait_group.read 0
//
// Dynamic shared memory layout (128-byte aligned):
// [0, num_bytes) : staging buffer
// [num_bytes, num_bytes+8) : mbarrier (uint64_t)

// One-block, one-warp TMA copy: moves num_bytes from src (GMEM) to dst (GMEM)
// by staging through dynamic shared memory (see file header for the design).
//
// Parameters:
//   dst       - destination global-memory pointer
//   src       - source global-memory pointer
//   num_bytes - number of bytes to copy; also the size of the SMEM staging
//               buffer (dynamic shared memory must be num_bytes + 8 bytes:
//               staging buffer followed by the 8-byte mbarrier)
//
// NOTE(review): cp.async.bulk imposes hardware size/alignment requirements
// (per the PTX ISA, 16-byte multiple sizes and 16-byte-aligned addresses).
// Nothing here validates num_bytes/src/dst — presumably the launcher
// guarantees this; confirm at the call site. Placing the mbarrier at
// smem + num_bytes also assumes num_bytes is 8-byte aligned — verify.
//
// __launch_bounds__(32, 1): exactly one warp per block; all TMA work is
// driven by thread 0, so __syncwarp() below is a full-block sync.
extern "C" __global__ void __launch_bounds__(32, 1) tma_copy_1d(
void* __restrict__ dst,
const void* __restrict__ src,
int num_bytes) {
// Dynamic SMEM: [0, num_bytes) staging buffer, then the 8-byte mbarrier.
extern __shared__ __align__(128) unsigned char smem[];

// mbarrier lives immediately after the staging buffer.
unsigned long long* mbar =
reinterpret_cast<unsigned long long*>(smem + num_bytes);
// PTX shared::cta instructions take 32-bit shared-window addresses, so
// convert the generic pointers to shared-state-space addresses.
unsigned int smem_addr =
static_cast<unsigned int>(__cvta_generic_to_shared(smem));
unsigned int mbar_addr =
static_cast<unsigned int>(__cvta_generic_to_shared(mbar));

if (threadIdx.x == 0) {
// Initialize the mbarrier with an arrival count of 1 (only thread 0
// arrives), then fence so the init is visible before any TMA traffic.
asm volatile(
"mbarrier.init.shared::cta.b64 [%0], %1;" ::"r"(mbar_addr), "r"(1));
asm volatile("fence.mbarrier_init.release.cluster;" :::);
}
// Single warp per block (see __launch_bounds__), so this publishes the
// initialized mbarrier to every thread in the block.
__syncwarp();

if (threadIdx.x == 0) {
// Announce expected transaction bytes on the mbarrier
asm volatile(
"mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;" ::"r"(
mbar_addr),
"r"(num_bytes));

// TMA Load: GMEM -> SMEM (async, completed via mbarrier)
asm volatile(
"cp.async.bulk.shared::cluster.global"
".mbarrier::complete_tx::bytes"
" [%0], [%1], %2, [%3];\n" ::"r"(smem_addr),
"l"(src),
"r"(num_bytes),
"r"(mbar_addr)
: "memory");

// Block until the mbarrier phase flips (TMA load completed).
// Parity 0 is correct: this is the mbarrier's first (and only) phase.
// try_wait can time out without the phase completing, hence the
// branch-back retry loop.
asm volatile(
"{\n"
".reg .pred P1;\n"
"TMA_COPY_WAIT_LOAD:\n"
"mbarrier.try_wait.parity.shared::cta.b64"
" P1, [%0], %1;\n"
"@P1 bra TMA_COPY_LOAD_DONE;\n"
"bra TMA_COPY_WAIT_LOAD;\n"
"TMA_COPY_LOAD_DONE:\n"
"}" ::"r"(mbar_addr),
"r"(0));

// TMA Store: SMEM -> GMEM
asm volatile(
"cp.async.bulk.global.shared::cta.bulk_group"
" [%0], [%1], %2;\n" ::"l"(dst),
"r"(smem_addr),
"r"(num_bytes)
: "memory");
// The bulk store is tracked by a bulk-group, not the mbarrier: commit
// the group and wait for read-completion before reusing/retiring SMEM.
asm volatile("cp.async.bulk.commit_group;");
asm volatile("cp.async.bulk.wait_group.read 0;" ::: "memory");

// Release the mbarrier now that all tracked operations are done.
asm volatile("mbarrier.inval.shared::cta.b64 [%0];" ::"r"(mbar_addr));
}
}
Loading
Loading