Description
Which component has the problem?
CUTLASS C++
Bug Report
Summary
TMA descriptors in the `Params` struct of `sm120_mma_tma_blockwise_scaling.hpp` and `sm90_epilogue_tma_warpspecialized.hpp` are not 64-byte aligned, causing `prefetch.tensormap` to fail with a "misaligned address" error on SM120 (RTX 5090).
Environment
- GPU: NVIDIA GeForce RTX 5090 (SM 12.0)
- Driver: 591.44
- CUDA Toolkit: 13.1
- CUTLASS: 4.3.3 (commit d55f6be)
- OS: Windows 11
- Location: Bare-metal
Steps to Reproduce
- Build CUTLASS example 87a with `-arch=sm_120a`:

```
nvcc -std=c++17 -O3 -arch=sm_120a --expt-relaxed-constexpr -I cutlass/include -I cutlass/tools/util/include -I cutlass/examples/common cutlass/examples/87_blackwell_geforce_gemm_blockwise/87a_blackwell_geforce_fp8_bf16_gemm_blockwise.cu -o example87a.exe
```

- Run the example:

```
./example87a.exe
```

- Observe a `misaligned address` error.
Root Cause Analysis
Using printf debugging, we traced the crash to prefetch_tma_descriptors():
```cpp
// sm120_mma_tma_blockwise_scaling.hpp:359
static void prefetch_tma_descriptors(Params const& mainloop_params) {
  cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
  // ^^^ CRASH HERE
}
```

Printing the TMA descriptor addresses revealed the misalignment:
```
[TMA] desc_a=0000000200610398 (mod64=24) // NOT 64-byte aligned!
[TMA] desc_b=0000000200610420 (mod64=32) // NOT 64-byte aligned!
```
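A minimal sketch of this kind of printf instrumentation, placed inside `prefetch_tma_descriptors()` (illustrative only, not the exact debug code used):

```cpp
// Debug-only instrumentation (illustrative, not CUTLASS code): dump each embedded
// TMA descriptor's address and its offset from a 64-byte boundary.
auto const* desc_a = mainloop_params.tma_load_a.get_tma_descriptor();
auto const* desc_b = mainloop_params.tma_load_b.get_tma_descriptor();
printf("[TMA] desc_a=%p (mod64=%llu)\n", (void const*)desc_a,
       reinterpret_cast<unsigned long long>(desc_a) % 64);
printf("[TMA] desc_b=%p (mod64=%llu)\n", (void const*)desc_b,
       reinterpret_cast<unsigned long long>(desc_b) % 64);
```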
The `prefetch.tensormap` PTX instruction requires a 64-byte-aligned address:

```
prefetch.tensormap [%0]; // requires 64-byte aligned address
```
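For context, `cute::prefetch_tma_descriptor` is essentially a thin wrapper that feeds the raw descriptor pointer to this instruction. A paraphrased sketch (not verbatim CuTe code, which lives in the SM90 copy headers and may differ between versions):

```cpp
#include <cstdint>

// Paraphrased sketch of what cute::prefetch_tma_descriptor boils down to (not
// verbatim library code): the descriptor pointer is passed straight to
// prefetch.tensormap, so a misaligned embedded descriptor faults right here.
__device__ inline void prefetch_tensormap_sketch(void const* desc_ptr) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
  asm volatile("prefetch.tensormap [%0];" : : "l"(gmem_int_desc) : "memory");
}
```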
The `Params` struct contains its `TMA_A tma_load_a` and `TMA_B tma_load_b` members without `alignas(64)`, causing the embedded TMA descriptors to be misaligned.
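As a self-contained illustration of the layout issue (hypothetical stand-in types, not the actual CUTLASS definitions), a struct only inherits the natural alignment of its members unless the 64-byte requirement is spelled out:

```cpp
#include <cstdio>

// Stand-in for cute::TmaDescriptor: a 128-byte blob with 8-byte natural alignment.
struct FakeTmaDescriptor { alignas(8) unsigned char bytes[128]; };

// Mirrors the problematic layout: nothing requests 64-byte alignment, so the
// struct (and the descriptors inside it) is only guaranteed 8-byte alignment.
struct ParamsNoAlign {
  FakeTmaDescriptor tma_load_a;
  FakeTmaDescriptor tma_load_b;
};

// Mirrors the proposed fix.
struct alignas(64) ParamsAligned {
  alignas(64) FakeTmaDescriptor tma_load_a;
  alignas(64) FakeTmaDescriptor tma_load_b;
};

int main() {
  std::printf("alignof(ParamsNoAlign) = %zu\n", alignof(ParamsNoAlign));  // prints 8
  std::printf("alignof(ParamsAligned) = %zu\n", alignof(ParamsAligned));  // prints 64
}
```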
Expected Behavior
TMA descriptor members in Params struct should be 64-byte aligned to satisfy prefetch.tensormap requirements.
Affected Files
- `cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp` (line ~254)
- `cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp` (line ~244)
Note: The SM90 epilogue file is also affected, which may impact SM90/SM100 kernels as well, though we only tested on SM120.
Appendix: Proposed Fix
Adding alignas(64) to the Params struct and TMA members resolves the issue:
Fix for sm120_mma_tma_blockwise_scaling.hpp:
```cpp
// Device side kernel params
struct alignas(64) Params {       // Add alignas(64) to struct
  using TMA_A = /* ... */;
  using TMA_B = /* ... */;
  alignas(64) TMA_A tma_load_a;   // Add alignas(64) to member
  alignas(64) TMA_B tma_load_b;   // Add alignas(64) to member
  // ...
};
```

Fix for `sm90_epilogue_tma_warpspecialized.hpp`:
```cpp
struct alignas(64) Params {       // Add alignas(64) to struct
  using TMA_C = /* ... */;
  using TMA_D = /* ... */;
  typename FusionCallbacks::Params thread{};
  alignas(64) TMA_C tma_load_c;   // Add alignas(64) to member
  alignas(64) TMA_D tma_store_d;  // Add alignas(64) to member
  // ...
};
```

Results After Fix
```
[TMA] desc_a=00000019006103C0 (mod64=0) // ALIGNED!
[TMA] desc_b=0000001900610480 (mod64=0) // ALIGNED!
[KERNEL] After mainloop prefetch, before epilogue prefetch
[KERNEL] After epilogue prefetch_tma_descriptors
[FP8 GEMM SM120] GEMM completed OK
```
The kernel runs without crashing after the alignment fix.
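To keep this from regressing, a compile-time guard could sit next to each Params definition. A sketch for the mainloop Params (hypothetical, not part of the proposed diff; `offsetof` on non-standard-layout types is only conditionally supported, though the major compilers accept it), with the epilogue analog checking `tma_load_c` and `tma_store_d`:

```cpp
#include <cstddef>  // offsetof

// Hypothetical regression guard, placed after the Params definition: together
// these imply the embedded TMA descriptors land on 64-byte boundaries, since
// the struct is 64-byte aligned and each member sits at a 64-byte offset.
static_assert(alignof(Params) % 64 == 0,
              "Params must be 64-byte aligned for prefetch.tensormap");
static_assert(offsetof(Params, tma_load_a) % 64 == 0,
              "tma_load_a must sit at a 64-byte offset within Params");
static_assert(offsetof(Params, tma_load_b) % 64 == 0,
              "tma_load_b must sit at a 64-byte offset within Params");
```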
Notes on Pull Request
We have not submitted a PR because:
- SM90 impact unknown: The epilogue fix touches `sm90_epilogue_tma_warpspecialized.hpp`, which is used by SM90/SM100 kernels. We do not have access to an H100/B100 to verify that the change does not break existing functionality.
- Separate correctness issue: After fixing the alignment, the kernel runs but produces ~12% relative error. This is likely a separate scale-factor issue, not related to alignment.
We are happy to submit a PR if the CUTLASS team confirms the SM90 change is safe.
Related Issues
- [QST][CUTEDSL] Address Misalignment in FP8 Gemm #2902 - Different alignment issue (`partition_S()` dropping LDSM alignment)
- SM 120 (Blackwell GeForce RTX 50 Series) Block-Scaled MMA Runtime Assertion Failure #2820 - Arch conditional MMA assertion (solved by `-arch=sm_120a`)