[BUG] SM120 NVFP4 GEMM (example 79a): misaligned address crash #2906

@m96-chan

Which component has the problem?

CUTLASS C++

Bug Report

Summary

Example 79a (79a_blackwell_geforce_nvfp4_bf16_gemm.cu) crashes with a "misaligned address" error on RTX 5090 (SM120). This is related to #2905 but affects a different mainloop: sm120_blockscaled_mma_tma.hpp (NVFP4) rather than sm120_mma_tma_blockwise_scaling.hpp (FP8).

Environment

  • GPU: NVIDIA GeForce RTX 5090 (SM 12.0)
  • CUDA Toolkit: 13.1
  • CUTLASS: v4.3.4
  • OS: Windows 11

Steps to Reproduce

  1. Build example 79a:
cd examples/79_blackwell_geforce_gemm
nvcc -O2 -arch=sm_120a -std=c++17 --expt-relaxed-constexpr \
  -I ../../include -I ../../tools/util/include \
  -DCUTLASS_ARCH_MMA_SM120_SUPPORTED=1 \
  79a_blackwell_geforce_nvfp4_bf16_gemm.cu -o 79a_test
  2. Run:
./79a_test --m=256 --n=256 --k=256
  3. Result:
========= COMPUTE-SANITIZER
========= Misaligned shared or local address
=========     at ...MainloopSm120TmaWarpSpecializedBlockScaled...

Root Cause Analysis

Two alignment issues:

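  1. TMA descriptor alignment: prefetch.tensormap requires the descriptor address to be 64-byte aligned, but the Params structs that hold the descriptors are not declared alignas(64).
  2. Scale factor smem alignment: ldmatrix.sync.aligned requires 16-byte-aligned shared memory addresses, but the scale factor buffers have no explicit alignment. Comparing the 79a and 87a mainloops: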
// sm120_mma_tma_blockwise_scaling.hpp (87a) - uses array_aligned (16-byte default)
cute::array_aligned<ElementSF, ...> smem_scale_A;

// sm120_blockscaled_mma_tma.hpp (79a) - no alignment
cute::ArrayEngine<ElementSF, ...> smem_SFA;  // Missing alignment
cute::ArrayEngine<ElementSF, ...> smem_SFB;
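
To illustrate the mechanism behind the second point, here is a minimal host-side sketch in plain C++17 with no CUTLASS headers. `Bytes`, `StorageUnaligned`, `StorageAligned`, and `is_aligned` are hypothetical stand-ins, not CUTLASS types, and the buffer sizes are invented so that the effect is visible. Whether the real scale factor buffers land on a misaligned offset depends on the actual tile shapes and stage count, which this sketch does not model.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for a raw buffer whose natural alignment is 1 byte.
template <std::size_t N>
struct Bytes { std::uint8_t data[N]; };

// Without explicit alignment, a member is placed directly after the
// previous one. The 1000-byte size is invented for illustration.
struct StorageUnaligned {
  alignas(1024) Bytes<1000> smem_A;
  Bytes<64> smem_SFA;  // offset 1000, not a multiple of 16
  Bytes<64> smem_SFB;  // offset 1064
};

// With alignas(128), the compiler pads each buffer up to a 128-byte boundary.
struct StorageAligned {
  alignas(1024) Bytes<1000> smem_A;
  alignas(128) Bytes<64> smem_SFA;  // offset 1024
  alignas(128) Bytes<64> smem_SFB;  // offset 1152
};

constexpr bool is_aligned(std::size_t offset, std::size_t alignment) {
  return offset % alignment == 0;
}

// Compile-time checks: the unaligned layout breaks a 16-byte requirement,
// the aligned layout satisfies a 128-byte one.
static_assert(!is_aligned(offsetof(StorageUnaligned, smem_SFA), 16));
static_assert(is_aligned(offsetof(StorageAligned, smem_SFA), 128));
static_assert(is_aligned(offsetof(StorageAligned, smem_SFB), 128));
```

The static_asserts run at compile time, so the snippet can be pasted into any translation unit. A similar offsetof check against the real TensorStorage would be one way to confirm or rule out this hypothesis.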

Expected Behavior

The example runs to completion and prints "Disposition: Passed".

Affected Files

  1. include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp
  2. include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp

Appendix: Proposed Fix

sm120_blockscaled_mma_tma.hpp:

 struct TensorStorage : cute::aligned_struct<128, _0> {
   alignas(1024) cute::ArrayEngine<SmemAllocTypeA, ...> smem_A;
   alignas(1024) cute::ArrayEngine<SmemAllocTypeB, ...> smem_B;
-  cute::ArrayEngine<ElementSF, ...> smem_SFA;
-  cute::ArrayEngine<ElementSF, ...> smem_SFB;
+  alignas(128) cute::ArrayEngine<ElementSF, ...> smem_SFA;
+  alignas(128) cute::ArrayEngine<ElementSF, ...> smem_SFB;
 };

-struct Params {
+struct alignas(64) Params {
-  TMA_A tma_load_a;
-  TMA_B tma_load_b;
-  TMA_SFA tma_load_sfa;
-  TMA_SFB tma_load_sfb;
+  alignas(64) TMA_A tma_load_a;
+  alignas(64) TMA_B tma_load_b;
+  alignas(64) TMA_SFA tma_load_sfa;
+  alignas(64) TMA_SFB tma_load_sfb;
   // ...
 };

sm90_epilogue_tma_warpspecialized.hpp:

-struct Params {
+struct alignas(64) Params {
   // ...
-  TMA_C tma_load_c;
-  TMA_D tma_store_d;
+  alignas(64) TMA_C tma_load_c;
+  alignas(64) TMA_D tma_store_d;
 };
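
The effect of the alignas(64) annotations in the diffs above can be sketched on the host side as follows. `FakeDescriptor`, `ParamsBefore`, `ParamsAfter`, and `is_aligned_to` are hypothetical names, and the leading `flags` member is invented to force a non-zero offset. The stand-in descriptor deliberately has no alignment of its own; the real TMA descriptor type may already carry an alignment requirement depending on the toolchain, so this shows only what the annotation guarantees, not what the CUTLASS structs currently do.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical 128-byte descriptor stand-in with 1-byte natural alignment.
struct FakeDescriptor { std::uint8_t opaque[128]; };

// Before: the descriptor follows a 4-byte member and lands at offset 4.
struct ParamsBefore {
  std::uint32_t flags;        // illustrative leading member
  FakeDescriptor tma_load_a;
};

// After: both the struct and each descriptor member are 64-byte aligned.
struct alignas(64) ParamsAfter {
  std::uint32_t flags;
  alignas(64) FakeDescriptor tma_load_a;  // offset 64
  alignas(64) FakeDescriptor tma_load_b;  // offset 192
};

// Runtime check usable on any pointer, for example before a kernel launch.
inline bool is_aligned_to(const void* p, std::size_t alignment) {
  return reinterpret_cast<std::uintptr_t>(p) % alignment == 0;
}

static_assert(offsetof(ParamsBefore, tma_load_a) == 4);
static_assert(offsetof(ParamsAfter, tma_load_a) == 64);
static_assert(offsetof(ParamsAfter, tma_load_b) == 192);
static_assert(alignof(ParamsAfter) == 64);
```

Annotating both the struct and its members matters: the member alignment fixes the offset inside the struct, while the struct alignment fixes the base address, and a descriptor is 64-byte aligned only if both hold.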

After the fix: Disposition: Passed, 3043 GFLOPS

Notes on Pull Request

The fix touches sm90_epilogue_tma_warpspecialized.hpp, which is shared across SM90/SM100/SM120. I only have SM120 hardware, so I cannot verify the impact on other architectures. Happy to submit a PR if the team would like one.
