diff --git a/csrc/flashinfer_xqa_ops.cu b/csrc/flashinfer_xqa_ops.cu
index 87a614d778..8aff9c007d 100644
--- a/csrc/flashinfer_xqa_ops.cu
+++ b/csrc/flashinfer_xqa_ops.cu
@@ -16,8 +16,8 @@
 #include "pytorch_extension_utils.h"
 
-void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingWinSize,
-                 double qScale, at::Tensor output,
+void xqa_wrapper(bool run_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads,
+                 int64_t slidingWinSize, double qScale, at::Tensor output,
 #if LOW_PREC_OUTPUT
     at::Tensor rcpOutScale,
 #endif
diff --git a/csrc/xqa/gmma.cuh b/csrc/xqa/gmma.cuh
new file mode 100644
index 0000000000..d1b2547fcd
--- /dev/null
+++ b/csrc/xqa/gmma.cuh
@@ -0,0 +1,145 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#pragma once
+#include "cuda_hint.cuh"
+#include "mha_stdheaders.cuh"
+#include "utils.cuh"
+#ifndef __CUDACC__
+#include <cuda_runtime.h>
+#endif
+#include <cuda_fp16.h>
+#include <cuda_fp8.h>
+
+namespace gmma {
+
+enum class SwizzleMode : uint64_t { kNONE = 0, k128 = 1, k64 = 2, k32 = 3 };
+
+struct MatDesc {
+  uint64_t addr : 16;
+  uint64_t dimKOffset : 16;
+  uint64_t dimMNOffset : 16;
+  uint64_t pad0 : 1;
+  uint64_t baseOffset : 3;
+  uint64_t pad1 : 10;
+  SwizzleMode swizzle : 2;
+
+  enum class Raw : uint64_t {};
+
+  [[nodiscard]] __device__ inline MatDesc withAddr(void const* data) const {
+    MatDesc ret = *this;
+    ret.addr = encode(__cvta_generic_to_shared(data));
+    return ret;
+  }
+
+  static __device__ inline uint32_t encode(uint32_t val) { return (val & 0x3FFFFU) >> 4; }
+
+  __device__ inline bool operator==(MatDesc const& other) const { return raw() == other.raw(); }
+
+  __device__ inline Raw const& raw() const {
+    static_assert(sizeof(MatDesc) == 8);
+    return reinterpret_cast<Raw const&>(*this);
+  }
+
+  static __device__ inline MatDesc fromRaw(Raw const& raw) {
+    return reinterpret_cast<MatDesc const&>(raw);
+  }
+};
+
+static_assert(sizeof(MatDesc) == 8);
+
+[[nodiscard]] __device__ inline MatDesc::Raw addAddr(MatDesc::Raw base, void const* data) {
+  assert((uint32_t(__cvta_generic_to_shared(data)) & ~0x3FFFFU) == 0);
+  MatDesc::Raw ret = base;
+  auto& u32x2 = reinterpret_cast<uint32_t(&)[2]>(ret);
+  u32x2[0] += static_cast<uint32_t>(__cvta_generic_to_shared(data)) >> 4;
+  return ret;
+}
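+// MatDesc mirrors the 64-bit shared-memory matrix descriptor consumed by
+// wgmma.mma_async: an encoded start address, the byte offsets along the K and
+// M/N dimensions (each encoded as (val & 0x3FFFF) >> 4), a 3-bit base offset
+// for swizzled layouts, and the swizzle mode in the top bits.
+// Illustrative sketch of building a descriptor for an unswizzled fp8 tile
+// (tile shape and offsets below are hypothetical, not part of this change):
+//   __shared__ __nv_fp8_e4m3 tile[64][32];
+//   MatDesc desc = makeMatDesc(&tile[0][0], /*dimKByteOffset=*/512,
+//                              /*dimMNByteOffset=*/32, SwizzleMode::kNONE);
+//   MatDesc::Raw raw = desc.raw();  // the 64-bit operand passed to mma_async_*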
+__device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset,
+                                      uint32_t dimMNByteOffset, void const* patternStartAddr,
+                                      SwizzleMode swizzleMode) {
+  uint32_t const patternAddr = __cvta_generic_to_shared(patternStartAddr);
+  uint32_t const baseAlign = [&]() -> uint32_t {
+    switch (swizzleMode) {
+      case SwizzleMode::kNONE:
+        return 1;
+      case SwizzleMode::k128:
+        return 1024;
+      case SwizzleMode::k64:
+        return 512;
+      case SwizzleMode::k32:
+        return 256;
+    }
+    asm volatile("trap;\n");
+    return 0;
+  }();
+  uint32_t const baseOffset = ((patternAddr % baseAlign == 0) ? 0U : ((patternAddr >> 0x7) & 0x7));
+  return MatDesc{
+      /*addr=*/MatDesc::encode(__cvta_generic_to_shared(data)),
+      /*dimKOffset=*/MatDesc::encode(dimKByteOffset),
+      /*dimMNOffset=*/MatDesc::encode(dimMNByteOffset),
+      /*pad0=*/0,
+      /*baseOffset=*/baseOffset,
+      /*pad1=*/0,
+      /*swizzle=*/swizzleMode,
+  };
+}
+
+__device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset,
+                                      uint32_t dimMNByteOffset, SwizzleMode swizzleMode) {
+  return makeMatDesc(data, dimKByteOffset, dimMNByteOffset, data, swizzleMode);
+}
+
+inline constexpr uint32_t instM = 64;
+
+template <typename MathElem>
+inline constexpr uint32_t instK = 32 / sizeof(MathElem);
+
+inline constexpr uint32_t instNBase = 8;
+
+// for both a and b, outer-dim is gemm-K and inner-dim is gemm-M or gemm-N
+// acc is used as both input and output.
+template <typename MathElem, uint32_t n, bool transA, bool transB>
+__device__ void mma_async_shmA(float (&acc)[exactDiv(n, instNBase)][2][2], MatDesc::Raw descA,
+                               MatDesc::Raw descB, bool accHasVal);
+template <typename MathElem, uint32_t n, bool transA, bool transB>
+__device__ void mma_async_regA(float (&acc)[exactDiv(n, instNBase)][2][2],
+                               uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal);
+
+__device__ inline void fence() { asm volatile("wgmma.fence.sync.aligned;\n"); }
+
+__device__ inline void commit_group() { asm volatile("wgmma.commit_group.sync.aligned;\n"); }
+
+template <uint32_t targetNbInFlightGroups>
+__device__ inline void wait_group() {
+  asm volatile("wgmma.wait_group.sync.aligned %0\n; " ::"n"(targetNbInFlightGroups));
+}
+
+template <bool swizzle, typename Array2D>
+constexpr SwizzleMode getSwizzleMode(Array2D const&) {
+  constexpr auto rowBytes = Array2D::rowBytes;
+  if constexpr (!swizzle) {
+    return SwizzleMode::kNONE;
+  }
+  if constexpr (rowBytes % 128 == 0) {
+    return SwizzleMode::k128;
+  } else if constexpr (rowBytes == 64) {
+    return SwizzleMode::k64;
+  } else {
+    static_assert(rowBytes == 32);
+    return SwizzleMode::k32;
+  }
+}
+}  // namespace gmma
+
+#include "gmma_impl.cuh"
diff --git a/csrc/xqa/gmma_impl.cuh b/csrc/xqa/gmma_impl.cuh
new file mode 100644
index 0000000000..b9515ddea9
--- /dev/null
+++ b/csrc/xqa/gmma_impl.cuh
@@ -0,0 +1,4971 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#pragma once
+#include "cuda_hint.cuh"
+#include "mha_stdheaders.cuh"
+#ifndef __CUDACC__
+#include <cuda_runtime.h>
+#endif
+#include <cuda_fp16.h>
+#include <cuda_fp8.h>
+
+namespace gmma {
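+// The specializations below are cog-generated for every supported n. A typical
+// issue sequence on the compute warpgroup looks like this sketch (descriptor
+// and accumulator names are illustrative, not part of this change):
+//   gmma::fence();             // order preceding register/SMEM writes
+//   gmma::mma_async_shmA<__nv_fp8_e4m3, 128, false, false>(acc, descA, descB,
+//                                                          /*accHasVal=*/false);
+//   gmma::commit_group();      // close the in-flight wgmma group
+//   gmma::wait_group<0>();     // block until all committed groups complete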
+// cog template. Do code generation with: pip install cogapp; cog -r $filename
+
+// clang-format off
+/*[[[cog
+import cog
+reg_list = lambda beg,end: ", ".join([f"%{i}" for i in range(beg, end)])
+acc_placeholder = lambda n: "{%s}" % reg_list(0, n//2)
+acc_registers = lambda n: "\n , ".join([f'"+f"(acc[{i}][0][0]), "+f"(acc[{i}][0][1]), "+f"(acc[{i}][1][0]), "+f"(acc[{i}][1][1])' for i in range(n//8)])
+ptx_eol = "\\n"
+n_list = [8, 16, 24, 32, 64, 128, 256]
+for n in n_list:
+    cog.outl(f'''
+template<>
+__device__ inline void mma_async_shmA<__nv_fp8_e4m3, {n}, false, false>(float(&acc)[{n//8}][2][2], MatDesc::Raw descA,
+MatDesc::Raw descB, bool accHasVal)
+{{
+    if (accHasVal) {{
+        asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k32.f32.e4m3.e4m3{ptx_eol}"
+            "{acc_placeholder(n)},{ptx_eol}" // d
+            "%{n//2},{ptx_eol}" //a-desc
+            "%{n//2+1},{ptx_eol}" //b-desc
+            "%{n//2+2}, 1, 1;{ptx_eol}"
+            : {acc_registers(n)}
+            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
+    }}
+    else {{
+        asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k32.f32.e4m3.e4m3{ptx_eol}"
+            "{acc_placeholder(n)},{ptx_eol}" // d
+            "%{n//2},{ptx_eol}" //a-desc
+            "%{n//2+1},{ptx_eol}" //b-desc
+            "%{n//2+2}, 1, 1;{ptx_eol}"
+            : {acc_registers(n)}
+            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
+    }}
+}}
+
+template<>
+__device__ inline void mma_async_regA<__nv_fp8_e4m3, {n}, false, false>(float(&acc)[{n//8}][2][2], uint32_t
+const(&a)[2][2][1], MatDesc::Raw descB, bool accHasVal)
+{{
+    if (accHasVal) {{
+        asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k32.f32.e4m3.e4m3{ptx_eol}"
+            "{acc_placeholder(n)},{ptx_eol}" // d
+            "{{{reg_list(n//2, n//2 + 4)}}},{ptx_eol}" //a
+            "%{n//2+4},{ptx_eol}" //b-desc
+            "%{n//2+5}, 1, 1;{ptx_eol}"
+            : {acc_registers(n)}
+            : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
+    }}
+    else {{
+        asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k32.f32.e4m3.e4m3{ptx_eol}"
+            "{acc_placeholder(n)},{ptx_eol}" // d
+            "{{{reg_list(n//2, n//2 + 4)}}},{ptx_eol}" //a
+            "%{n//2+4},{ptx_eol}" //b-desc
+            "%{n//2+5}, 1, 1;{ptx_eol}"
+            : {acc_registers(n)}
+            : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
+    }}
+}}
+''')
+
+for n in n_list:
+    for transA in [0, 1]:
+        for transB in [0, 1]:
+            for t,s in [('half', 'f16'), ('__nv_bfloat16', 'bf16')]:
+                cog.outl(f'''
+template<>
+__device__ inline void mma_async_shmA<{t}, {n}, {transA}, {transB}>(float(&acc)[{n//8}][2][2], MatDesc::Raw descA,
+MatDesc::Raw descB, bool accHasVal)
+{{
+    if (accHasVal) {{
+        asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k16.f32.{s}.{s}{ptx_eol}"
+            "{acc_placeholder(n)},{ptx_eol}" // d
+            "%{n//2},{ptx_eol}" //a-desc
+            "%{n//2+1},{ptx_eol}" //b-desc
+            "%{n//2+2}, 1, 1, {transA}, {transB};{ptx_eol}"
+            : {acc_registers(n)}
+            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
+    }}
+    else {{
+        asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k16.f32.{s}.{s}{ptx_eol}"
+            "{acc_placeholder(n)},{ptx_eol}" // d
+            "%{n//2},{ptx_eol}" //a-desc
+            "%{n//2+1},{ptx_eol}" //b-desc
+            "%{n//2+2}, 1, 1, {transA}, {transB};{ptx_eol}"
+            : {acc_registers(n)}
+            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
+    }}
+}}
+''')
+                if transA == 0:
+                    cog.outl(f'''
+template<>
+__device__ inline void mma_async_regA<{t}, {n}, {transA}, {transB}>(float(&acc)[{n//8}][2][2], uint32_t
+const(&a)[2][2][1], MatDesc::Raw descB, bool accHasVal)
+{{
+    if (accHasVal)
{{ + asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k16.f32.{s}.{s}{ptx_eol}" + "{acc_placeholder(n)},{ptx_eol}" // d + "{{{reg_list(n//2, n//2 + 4)}}},{ptx_eol}" //a + "%{n//2+4},{ptx_eol}" //b-desc + "%{n//2+5}, 1, 1, {transB};{ptx_eol}" + : {acc_registers(n)} + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), "l"(reinterpret_cast(descB)), "n"(true)); + }} + else {{ + asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k16.f32.{s}.{s}{ptx_eol}" + "{acc_placeholder(n)},{ptx_eol}" // d + "{{{reg_list(n//2, n//2 + 4)}}},{ptx_eol}" //a + "%{n//2+4},{ptx_eol}" //b-desc + "%{n//2+5}, 1, 1, {transB};{ptx_eol}" + : {acc_registers(n)} + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), "l"(reinterpret_cast(descB)), "n"(false)); + }} +}} +''') +]]]*/ +// clang-format on + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 8, false, false>(float (&acc)[1][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 8, false, false>(float (&acc)[1][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 16, false, false>(float (&acc)[2][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), 
"+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 16, false, false>(float (&acc)[2][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 24, false, false>(float (&acc)[3][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 24, false, false>(float (&acc)[3][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1;\n" + : 
"+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 32, false, false>(float (&acc)[4][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 32, false, false>(float (&acc)[4][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 64, false, false>(float (&acc)[8][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + 
bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 64, false, false>(float (&acc)[8][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, 
%7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 128, false, false>(float (&acc)[16][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" 
// b-desc + "%66, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 128, false, false>( + float (&acc)[16][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), 
"r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 256, false, false>(float (&acc)[32][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), 
"+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + 
"+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 256, false, false>( + float (&acc)[32][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1;\n" + 
: "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, 
%116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[1][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + 
"wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[1][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 0, 0>(float (&acc)[1][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 8, 0, 0>(float (&acc)[1][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[1][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + 
"wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[1][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 0, 1>(float (&acc)[1][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 8, 0, 1>(float (&acc)[1][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[1][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + 
"wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 1, 0>(float (&acc)[1][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[1][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 1, 1>(float (&acc)[1][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[2][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + 
"%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[2][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 0, 0>(float (&acc)[2][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 16, 0, 0>(float (&acc)[2][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), 
"r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[2][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[2][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 0, 1>(float (&acc)[2][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 16, 0, 1>(float (&acc)[2][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + 
"wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[2][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 1, 0>(float (&acc)[2][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[2][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : 
"l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 1, 1>(float (&acc)[2][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[3][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[3][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + 
asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 0, 0>(float (&acc)[3][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 24, 0, 0>(float (&acc)[3][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[3][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 
1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[3][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 0, 1>(float (&acc)[3][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 24, 0, 1>(float 
(&acc)[3][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[3][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 1, 0>(float (&acc)[3][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), 
"+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[3][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 1, 1>(float (&acc)[3][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[4][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + 
"wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[4][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 0, 0>(float (&acc)[4][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : 
"l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 32, 0, 0>(float (&acc)[4][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[4][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[4][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), 
"+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 0, 1>(float (&acc)[4][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 32, 0, 1>(float (&acc)[4][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, 
%13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[4][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 1, 0>(float (&acc)[4][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[4][2][2], MatDesc::Raw descA, + 
+
+template <>
+__device__ inline void mma_async_shmA<half, 32, 1, 1>(float (&acc)[4][2][2], MatDesc::Raw descA,
+                                                      MatDesc::Raw descB, bool accHasVal) {
+  if (accHasVal) {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n"  // d
+        "%16,\n"  // a-desc
+        "%17,\n"  // b-desc
+        "%18, 1, 1, 1, 1;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+          "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]),
+          "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1])
+        : "l"(reinterpret_cast<uint64_t const&>(descA)),
+          "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
+  } else {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n"  // d
+        "%16,\n"  // a-desc
+        "%17,\n"  // b-desc
+        "%18, 1, 1, 1, 1;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+          "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]),
+          "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1])
+        : "l"(reinterpret_cast<uint64_t const&>(descA)),
+          "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
+  }
+}
+
+template <>
+__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 1, 1>(float (&acc)[4][2][2],
+                                                               MatDesc::Raw descA,
+                                                               MatDesc::Raw descB, bool accHasVal) {
+  if (accHasVal) {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n"  // d
+        "%16,\n"  // a-desc
+        "%17,\n"  // b-desc
+        "%18, 1, 1, 1, 1;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+          "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]),
+          "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1])
+        : "l"(reinterpret_cast<uint64_t const&>(descA)),
+          "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
+  } else {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n"  // d
+        "%16,\n"  // a-desc
+        "%17,\n"  // b-desc
+        "%18, 1, 1, 1, 1;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+          "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]),
+          "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1])
+        : "l"(reinterpret_cast<uint64_t const&>(descA)),
+          "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
+  }
+}
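From here the generator jumps to n=64, where the operand numbering gets wide enough to be worth spelling out. A hedged helper mirroring the indices used in the blocks that follow (illustrative, not part of the patch):

```cuda
// For the shmA form: d registers are %0..%(4n/8-1), then a-desc, b-desc, and
// the scale-d predicate. E.g. n=64 -> d=%0..%31, a-desc=%32, b-desc=%33, pred=%34.
constexpr uint32_t nbAccRegs(uint32_t n) { return n / 8 * 4; }
constexpr uint32_t aDescIdx(uint32_t n) { return nbAccRegs(n); }
constexpr uint32_t bDescIdx(uint32_t n) { return nbAccRegs(n) + 1; }
constexpr uint32_t predIdx(uint32_t n) { return nbAccRegs(n) + 2; }
static_assert(predIdx(64) == 34);
```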
"+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[8][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), 
"+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 0, 0>(float (&acc)[8][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 64, 0, 0>(float (&acc)[8][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), 
"+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[8][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), 
"+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[8][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 0, 1>(float (&acc)[8][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), 
"+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 64, 0, 1>(float (&acc)[8][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), 
"+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[8][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 1, 0>(float (&acc)[8][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), 
"+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[8][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 1, 
1>(float (&acc)[8][2][2],
+                                                               MatDesc::Raw descA,
+                                                               MatDesc::Raw descB, bool accHasVal) {
+  if (accHasVal) {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, "
+        "%19, %20, %21, %22, "
+        "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n"  // d
+        "%32,\n"                                           // a-desc
+        "%33,\n"                                           // b-desc
+        "%34, 1, 1, 1, 1;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+          "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]),
+          "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]),
+          "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+          "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]),
+          "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]),
+          "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1])
+        : "l"(reinterpret_cast<uint64_t const&>(descA)),
+          "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
+  } else {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, "
+        "%19, %20, %21, %22, "
+        "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n"  // d
+        "%32,\n"                                           // a-desc
+        "%33,\n"                                           // b-desc
+        "%34, 1, 1, 1, 1;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+          "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]),
+          "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]),
+          "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+          "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]),
+          "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]),
+          "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1])
+        : "l"(reinterpret_cast<uint64_t const&>(descA)),
+          "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
+  }
+}
+
+template <>
+__device__ inline void mma_async_shmA<half, 128, 0, 0>(float (&acc)[16][2][2], MatDesc::Raw descA,
+                                                       MatDesc::Raw descB, bool accHasVal) {
+  if (accHasVal) {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, "
+        "%19, %20, %21, %22, "
+        "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, "
+        "%41, %42, %43, "
+        "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, "
+        "%62, %63},\n"  // d
+        "%64,\n"        // a-desc
+        "%65,\n"        // b-desc
+        "%66, 1, 1, 0, 0;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+          "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]),
+          "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]),
+          "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+          "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]),
+          "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]),
+          "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]),
+          "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]),
+          "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]),
+          "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]),
+          "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]),
+          "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]),
+          "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]),
+          "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]),
+          "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1])
+        : "l"(reinterpret_cast<uint64_t const&>(descA)),
+          "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
+  } else {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, "
+        "%19, %20, %21, %22, "
+        "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, "
+        "%41, %42, %43, "
+        "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, "
+        "%62, %63},\n"  // d
+        "%64,\n"        // a-desc
+        "%65,\n"        // b-desc
+        "%66, 1, 1, 0, 0;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+          "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]),
+          "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]),
+          "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+          "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]),
+          "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]),
+          "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]),
+          "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]),
+          "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]),
+          "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]),
+          "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]),
+          "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]),
+          "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]),
+          "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]),
+          "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1])
+        : "l"(reinterpret_cast<uint64_t const&>(descA)),
+          "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
+  }
+}
+
+template <>
+__device__ inline void mma_async_regA<half, 128, 0, 0>(float (&acc)[16][2][2],
+                                                       uint32_t const (&a)[2][2][1],
+                                                       MatDesc::Raw descB, bool accHasVal) {
+  if (accHasVal) {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, "
+        "%19, %20, %21, %22, "
+        "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, "
+        "%41, %42, %43, "
+        "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, "
+        "%62, %63},\n"  // d
+        "{%64, %65, %66, %67},\n"  // a
+        "%68,\n"                   // b-desc
+        "%69, 1, 1, 0;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+          "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]),
"+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 0, 0>(float (&acc)[16][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, 
%20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), 
"+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 128, 0, 0>(float (&acc)[16][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), 
"+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[16][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), 
"+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[16][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, 
%45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 0, 1>(float (&acc)[16][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), 
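+          // bf16 variant: same accumulator layout as f16, only the instruction dtype changes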
"+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 128, 0, 1>(float (&acc)[16][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + 
"+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[16][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), 
"+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 1, 0>(float (&acc)[16][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, 
%36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), 
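+          // scale-d immediate: 0 selects D = A*B (overwrite), 1 selects D = A*B + D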
"n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[16][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), 
"+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 1, 1>(float (&acc)[16][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + 
"+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[32][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + 
"+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), 
"+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA(float (&acc)[32][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + 
"+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + 
"+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 0, 0>(float (&acc)[32][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), 
"+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), 
"+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 256, 0, 0>(float (&acc)[32][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), 
"+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 0;\n" + : 
"+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[32][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, 
%91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, 
%69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, "
+ "%83, %84, %85, "
+ "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, "
+ "%103, %104, %105, "
+ "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, "
+ "%121, %122, "
+ "%123, %124, %125, %126, %127},\n" // d
+ "%128,\n" // a-desc
+ "%129,\n" // b-desc
+ "%130, 1, 1, 0, 1;\n"
+ : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+ "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+ "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]),
+ "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]),
+ "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+ "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]),
+ "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]),
+ "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]),
+ "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]),
+ "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]),
+ "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]),
+ "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]),
+ "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]),
+ "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]),
+ "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]),
+ "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]),
+ "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]),
+ "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]),
+ "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]),
+ "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]),
+ "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]),
+ "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]),
+ "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]),
+ "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]),
+ "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]),
+ "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]),
+ "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]),
+ "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]),
+ "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]),
+ "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]),
+ "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]),
+ "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1])
+ : "l"(reinterpret_cast<uint64_t const&>(descA)),
+ "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
+ }
+}
+
+template <>
+__device__ inline void mma_async_regA<half, 256, 0, 1>(float (&acc)[32][2][2],
+ uint32_t const (&a)[2][2][1],
+ MatDesc::Raw descB, bool accHasVal) {
+ if (accHasVal) {
+ asm volatile(
+ "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n"
+ "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16,
%17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), 
"n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + 
"+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 0, 1>(float (&acc)[32][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + 
"+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), 
"+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 256, 0, 1>(float (&acc)[32][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), 
"+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), 
"+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[32][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), 
"+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), 
"+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 1, 0>(float (&acc)[32][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), 
"+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + 
"+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA(float (&acc)[32][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), 
"+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + 
"+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 1, 1>(float (&acc)[32][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), 
"+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, 
%124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +//[[[end]]] +} // namespace gmma diff --git a/csrc/xqa/mha.cu b/csrc/xqa/mha.cu index c896017780..a951e27dca 100644 --- a/csrc/xqa/mha.cu +++ b/csrc/xqa/mha.cu @@ -476,7 +476,7 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy col + actualQSeqLen < nbValidCols ? true : packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart)); - acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : -INFINITY; + acc(m, n)(i, j) = maskFlag && col < nbValidCols ? 
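// --- Why safeInitRowMax instead of -INFINITY (illustrative) ---
// With -INFINITY, a fully masked row makes rowMax == -INFINITY and the softmax
// rescale computes expf(-INFINITY - (-INFINITY)), i.e. expf(NaN), which can
// propagate through the shared max/sum reductions. A finite sentinel keeps the
// arithmetic well-defined; padding rows then produce finite values that are
// simply never written out:
//
//   float const bad = expf(-INFINITY - (-INFINITY));          // NaN
//   float const ok  = expf(safeInitRowMax - safeInitRowMax);  // 1.0f, harmless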
acc(m, n)(i, j) : safeInitRowMax; } } } @@ -2709,11 +2709,7 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32 #if SPEC_DEC mask, #endif - attentionSinks, cacheList, -#if BEAM_WIDTH > 1 - beamSearchParams, -#endif - batchSize, kvCacheScale, semaphores, scratch); + attentionSinks, cacheList, batchSize, kvCacheScale, semaphores, scratch); #else KVCacheList const cacheList{kvCacheData, seqLen, maxSeqLen}; #ifndef NDEBUG diff --git a/csrc/xqa/mha.h b/csrc/xqa/mha.h index 77d8a2fd2f..171524f0b1 100644 --- a/csrc/xqa/mha.h +++ b/csrc/xqa/mha.h @@ -186,6 +186,20 @@ void launchHopperF8MHA( #endif uint32_t* semaphores, void* scratch, cudaStream_t stream); +void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, + uint32_t slidingWinSize, float qScale, OutputHead* output, +#if LOW_PREC_OUTPUT + float const* rcpOutScale, +#endif + InputHead const* q, float const* attentionSinks, + GMemCacheHead* pool, KVCachePageIndex const* kvCachePageList, + uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, + float const* __restrict__ kvCacheScale, +#if SPEC_DEC + uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, +#endif + uint32_t* semaphores, void* scratch, cudaStream_t stream); + void launchMLA( cudaDeviceProp const& prop, uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed diff --git a/csrc/xqa/mha_sm90.cu b/csrc/xqa/mha_sm90.cu new file mode 100644 index 0000000000..286ee08ec5 --- /dev/null +++ b/csrc/xqa/mha_sm90.cu @@ -0,0 +1,3271 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */ + +#include "cuda_hint.cuh" +#include "defines.h" +#if !(IS_MLA) +#include "barriers.cuh" +#include "utils.cuh" +#include "utils.h" + +#if SPEC_DEC +#define Q_HEADS_PER_CTA 64 +#include "specDec.h" +#endif + +#ifndef GENERATE_CUBIN +#include + +#include "hostUtils.h" +#include "tensorMap.h" +#endif +#include "gmma.cuh" +#include "mha.h" +#include "mhaUtils.cuh" +#include "mha_stdheaders.cuh" +#include "tma.h" + +#define DBG_PRINT 0 + +#ifdef SPEC_Q_SEQ_LEN +static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN is only supported for SPEC_DEC"); +constexpr uint32_t specDecQLen = SPEC_Q_SEQ_LEN; +static_assert(specDecQLen * headGrpSize <= 32, "SPEC_Q_SEQ_LEN macro value is too large"); +#define SWAP_AB 1 +#else +#define SWAP_AB (!SPEC_DEC) +#endif + +#define IS_SUPPORTED_F16_CASE \ + (CACHE_ELEM_ENUM == 0 && !SPEC_DEC && SWAP_AB && !USE_INPUT_KV && !LOW_PREC_OUTPUT) + +inline constexpr bool swapAB = SWAP_AB; + +#pragma region Config + +static_assert((inputElemSize == cacheElemSize && mha::is_same_v) || + inputElemSize > cacheElemSize); +using MathElem = + mha::conditional_t<(inputElemSize > cacheElemSize && mha::is_same_v), + InputElem, CacheElem>; + +constexpr uint32_t gmmaWarpsPerGrp = 4; +constexpr uint32_t gmmaWarpGrpSize = warp_size * gmmaWarpsPerGrp; +constexpr uint32_t gemm0NbGmmaGrps = 1; +constexpr uint32_t gemm0NbThrds = gmmaWarpGrpSize * gemm0NbGmmaGrps; +constexpr uint32_t gemm0NbWarps = gmmaWarpsPerGrp * gemm0NbGmmaGrps; +#if SPEC_DEC && !SWAP_AB +inline constexpr uint32_t ctaNbQHeads = Q_HEADS_PER_CTA; +inline constexpr uint32_t inputTokensPerCta = ctaNbQHeads / headGrpSize; +constexpr uint32_t ctaNbValidQHeads = ctaNbQHeads; +#elif SPEC_DEC && SWAP_AB +inline constexpr uint32_t inputTokensPerCta = specDecQLen; +inline constexpr uint32_t ctaNbValidQHeads = headGrpSize * inputTokensPerCta; +inline constexpr uint32_t ctaNbQHeads = []() { + static_assert(ctaNbValidQHeads <= 32, "ctaNbValidQHeads cannot exceed 32"); + if constexpr (ctaNbValidQHeads <= 8) { + return 8; + } + if constexpr (ctaNbValidQHeads <= 16) { + return 16; + } + return 32; +}(); +#else +inline constexpr uint32_t ctaNbValidQHeads = headGrpSize * beamWidth; +inline constexpr uint32_t ctaNbQHeads = roundUp(ctaNbValidQHeads, swapAB ? 
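// --- Padding example (assuming roundUp(x, m) == divUp(x, m) * m) ---
// ctaNbQHeads pads headGrpSize * beamWidth up to the gmma operand granularity:
// 8 rows when the Q heads feed the N dimension (swapAB), 64 when they feed M.
// E.g. with headGrpSize == 12 and beamWidth == 1:
//
//   roundUp(12U, 8U)  == 16U;  // swapAB: 4 padded head rows
//   roundUp(12U, 64U) == 64U;  // !swapAB: 52 padded head rows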
8U : 64U); +inline constexpr uint32_t inputTokensPerCta = 1; +#endif +constexpr uint32_t gemm0WarpGrpTileNbTokens = 64; +inline constexpr uint32_t gemm0CtaTileNbTokens = gemm0WarpGrpTileNbTokens * gemm0NbGmmaGrps; +constexpr uint32_t gemm1NbGmmaGrps = 1; +constexpr uint32_t gemm1NbThrds = gmmaWarpGrpSize * gemm1NbGmmaGrps; +constexpr uint32_t gemm1NbWarps = gmmaWarpsPerGrp * gemm1NbGmmaGrps; +constexpr uint32_t gemm1CtaTileNbTokens = gemm0CtaTileNbTokens; +constexpr uint32_t mathHeadBytes = sizeof(Vec); +constexpr uint32_t nbIOWarps = 4; +constexpr uint32_t nbIOThrds = warp_size * nbIOWarps; +constexpr uint32_t multiBlockMinNbTilesPerCta = 1; // 3; // @fixme: need tuning +constexpr uint32_t multiBlockMinNbTiles = multiBlockMinNbTilesPerCta * 2; +constexpr uint32_t nbWarps = gemm0NbWarps + gemm1NbWarps + nbIOWarps; + +constexpr uint32_t cacheHeadPartBytes = mha::min(paddedCacheHeadBytes, 128U); +constexpr uint32_t cacheHeadNbParts = + exactDiv(paddedCacheHeadBytes, cacheHeadPartBytes); // @fixme: support divUp in the future +constexpr uint32_t cacheHeadPartElems = exactDiv(headElems, cacheHeadNbParts); +constexpr uint32_t swizzleBytes = cacheHeadPartBytes; +static_assert(swizzleBytes == 128 || swizzleBytes == 64 || swizzleBytes == 32); + +constexpr bool needInputCvt = + inputElemSize > cacheElemSize&& mha::is_same_v; +constexpr bool needCacheCvt = inputElemSize > cacheElemSize&& mha::is_same_v; +static_assert(needInputCvt || needCacheCvt || mha::is_same_v); + +using ShmQWiseVec = Vec; + +constexpr uint32_t qPartBytes = mha::min(mathHeadBytes, 128U); +constexpr uint32_t nbQParts = exactDiv(mathHeadBytes, qPartBytes); +constexpr uint32_t grainsPerQPart = exactDiv(qPartBytes, grainBytes); + +constexpr uint32_t xPartBytes = mha::min(cacheElemSize * gemm0CtaTileNbTokens, 128U); +constexpr uint32_t nbXParts = exactDiv(cacheElemSize * gemm0CtaTileNbTokens, xPartBytes); +constexpr uint32_t grainsPerXPart = exactDiv(xPartBytes, grainBytes); +constexpr uint32_t cacheElemsPerGrain = exactDiv(grainBytes, cacheElemSize); + +constexpr uint32_t grainsPerIOHead = exactDiv(ioHeadBytes, grainBytes); +constexpr uint32_t grainsPerPaddedInputHead = exactDiv(paddedInputHeadBytes, grainBytes); + +#if USE_BEAM_SEARCH +constexpr uint32_t beamSearchGemm0CtaTileNbTokens = exactDiv(gemm0CtaTileNbTokens, beamWidth); +#endif + +using PaddedOutHead = PaddedInputHead; + +#pragma endregion Config + +struct alignas(128) SharedMem { + using KBuffer = Array2D; + static constexpr uint32_t nbKBuf = 2; + KBuffer k[nbKBuf]; // as is loaded from global mem. + using XBuffer = Vec, nbXParts>; + static constexpr uint32_t nbXBuf = + 2 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens + ? 
1 + : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens)); + using VBuffer = + Vec, + cacheHeadNbParts>; +#if !SWAP_AB + using VTBuffer = + Array2D; +#endif + static constexpr uint32_t nbVBuf = 2; +#if CACHE_ELEM_ENUM == 0 + using OutSwizzleBuf = Array2D; +#elif CACHE_ELEM_ENUM == 2 + using OutSwizzleBuf = Array2D, 4>, ctaNbQHeads, exactDiv(headElems, 4 * 4)>; +#endif + static_assert(nbXBuf == nbVBuf); + + union ReusedXVOutSwizzleBuf { + struct XV { + XBuffer x; + VBuffer v; +#if !SWAP_AB + VTBuffer vt; +#endif + // @fixme: also put xColMax and xColSum here + } xv; + + OutSwizzleBuf outSwizzle; + } reusedXVOutSwizzleBuf[nbXBuf]; + + static_assert(sizeof(OutSwizzleBuf) <= sizeof(SharedMem::ReusedXVOutSwizzleBuf::XV), + "need to use split output to avoid excessive shared memory usage"); + + __device__ inline XBuffer& xBuf(uint32_t i) { return reusedXVOutSwizzleBuf[i].xv.x; } + + __device__ inline VBuffer& vBuf(uint32_t i) { return reusedXVOutSwizzleBuf[i].xv.v; } +#if !SWAP_AB + __device__ inline VTBuffer& vtBuf(uint32_t i) { return reusedXVOutSwizzleBuf[i].xv.vt; } +#endif + __device__ inline OutSwizzleBuf& outSwizzleBuf(uint32_t i) { + return reusedXVOutSwizzleBuf[i].outSwizzle; + } + + using QBuffer = Vec, nbQParts>; + QBuffer q; // For gmma math. Conversion done if needed. + + // @fixme: move these into reusedXVOutSwizzleBuf +#if SWAP_AB + ShmQWiseVec xColMax[nbXBuf]; + ShmQWiseVec xColSum[nbXBuf][gemm0NbWarps]; +#else + ShmQWiseVec xRowMax[nbXBuf]; + ShmQWiseVec xRowSum[nbXBuf]; +#endif + + ShmQWiseVec gemm0CurrentSeqMax; + // col sum and max for the current gemm1 acc. Use shared memory to save some registers. register + // storage will be 8x duplicate for swapAB and 4x duplicate for non-swapAB. + ShmQWiseVec gemm1AccColMax; + ShmQWiseVec gemm1AccColSum; + +#if USE_PAGED_KV_CACHE + static constexpr uint32_t nbPagesPerTile = + gemm0CtaTileNbTokens >= tokensPerPage ? exactDiv(gemm0CtaTileNbTokens, tokensPerPage) : 1; + Vec pages[2]; // one for K and one for V +#endif + + // mem barriers + + CtaBarrierPair qBar; + CtaBarrierPair kBar[nbKBuf]; + CtaBarrierPair vBar[nbVBuf]; +#if !SWAP_AB + CtaBarrierPair vtBar[nbVBuf]; +#endif + CtaBarrierPair xBar[nbXBuf]; + + // used internally in the gemm0 warp group + // @fixme: use separate arrive and wait for all usage + CtaBarrier gemm0WarpGrpBar; + + // used internally in the gemm1 warp group + // @fixme: use separate arrive and wait for all usage + CtaBarrier gemm1WarpGrpBar; + + bool isLastCta; +}; + +CUBIN_EXPORT __device__ constexpr uint32_t smemSize = sizeof(SharedMem); +#ifdef __CUDA_ARCH__ +static_assert(smemSize < kMAX_SMEM_SIZE); +#endif + +constexpr uint32_t nbQLdWarps = needInputCvt ? 
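// --- Host-side launch sketch (standard CUDA API; grid/stream names hypothetical) ---
// sizeof(SharedMem) typically exceeds the 48 KiB static shared-memory limit, so
// the kernel must opt in to large dynamic shared memory before launch:
//
//   cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize);
//   kernel_mha<<<grid, dim3{128, 1, 3}, smemSize, stream>>>(/*...*/);
//
// which matches the extern __shared__ buffer and the dynamicSmemSize() assert
// in the kernel body below.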
nbIOWarps - 2 : 1; +constexpr uint32_t nbQLdThrds = warp_size * nbQLdWarps; + +#if CACHE_ELEM_ENUM == 0 || CACHE_ELEM_ENUM == 2 +template +struct F16QToF8Converter { + static_assert(inputElemSize == 2); + using F16Vec = Vec; +#if CACHE_ELEM_ENUM == 0 + using ShmVec = F16Vec; +#elif CACHE_ELEM_ENUM == 2 + using F8Vec = Vec; + using ShmVec = F8Vec; +#endif + + static constexpr uint32_t grainsPerPaddedInputHead = exactDiv(paddedInputHeadBytes, grainBytes); + static constexpr uint32_t grainsPerPaddedInputQHeadGrp = grainsPerPaddedInputHead * headGrpSize; +#if !(SPEC_DEC) + static constexpr uint32_t totalGrains = grainsPerPaddedInputQHeadGrp * beamWidth; +#else + static_assert(beamWidth == 1); + static constexpr uint32_t totalGrains = grainsPerPaddedInputQHeadGrp * inputTokensPerCta; +#endif + static constexpr uint32_t nbIters = divUp(totalGrains, nbThrds); + + using RegData = Vec; + + static __device__ RegData load(uint32_t tid, TinyPtr const& src, + uint32_t const nbKHeads /*for beam search and spec dec*/, + uint32_t nbTokens); + static __device__ void store(uint32_t tid, SharedMem::QBuffer& dst, RegData const& data); +}; +#endif // CACHE_ELEM_ENUM + +struct KVTilePartLoader { + static constexpr uint32_t nbParts = cacheHeadNbParts; + static constexpr uint32_t partElems = exactDiv(headElems, nbParts); + +#if USE_PAGED_KV_CACHE + static_assert(gemm0CtaTileNbTokens % tokensPerPage == 0 || + tokensPerPage % gemm0CtaTileNbTokens == 0); + static constexpr uint32_t nbPagesPerTile = SharedMem::nbPagesPerTile; +#endif + + uint32_t const nbKHeads; + KVCacheList const& cacheList; + uint32_t const idxReq; + uint32_t const idxHeadGrp; + + CUtensorMap const& tensorMap; +#if USE_PAGED_KV_CACHE + uint32_t const nbPages; // for bound check + Vec& pages; + uint32_t idxTileRef; // idxTile used to load the pages +#endif + uint32_t const baseOffset; + + __device__ KVTilePartLoader(bool isK, uint32_t nbKHeads, + KVCacheList const& cacheList, uint32_t idxReq, + uint32_t idxHeadGrp, CUtensorMap const& tensorMap +#if USE_PAGED_KV_CACHE + , + uint32_t nbPages, Vec& pageBuf +#endif + ); + // tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache + template + __device__ void loadData( + Array2D& dst, + uint32_t idxTile, uint32_t idxPart, CtaBarrier& bar); + + __device__ void loadPages(uint32_t idxTile); + __device__ GMemKVCacheHead& getHead(uint32_t pos); +}; + +using GmmaAccCoreMat = Array2D; +template +using GmmaAcc = + Array2D; + +inline constexpr uint32_t gemm0M = (swapAB ? gemm0CtaTileNbTokens : ctaNbQHeads); +inline constexpr uint32_t gemm0N = (swapAB ? ctaNbQHeads : gemm0CtaTileNbTokens); + +using Gemm0Acc = GmmaAcc; + +#if SWAP_AB +using RegColWiseVec = Vec, Gemm0Acc::cols>; +using UniformNeedRescaleMask = Vec; +using RegSeqWiseVec = RegColWiseVec; +#else +using RegRowWiseVec = Vec, Gemm0Acc::rows>; +using UniformNeedRescaleMask = + Vec; +using RegSeqWiseVec = RegRowWiseVec; +#endif + +#if SPEC_DEC + +__device__ inline uint32_t getInputSeqLen(SpecDecParams const& params, uint32_t idxReq) { + return (params.qCuSeqLens == nullptr) ? params.qSeqLen + : params.qCuSeqLens[idxReq + 1] - params.qCuSeqLens[idxReq]; +} + +__device__ inline uint32_t getInputTokOffset(SpecDecParams const& params, uint32_t idxReq) { + return (params.qCuSeqLens == nullptr) ? 
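// --- Cumulative-length layout (worked example) ---
// qCuSeqLens holds running token offsets (nbReq + 1 entries, starting at 0),
// so for qCuSeqLens == {0, 3, 7, 12}:
//
//   getInputSeqLen(params, 1)    == 7 - 3 == 4;  // request 1 has 4 query tokens
//   getInputTokOffset(params, 1) == 3;           // its tokens start at index 3
//
// The nullptr fallback treats every request as a fixed-size qSeqLen block.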
params.qSeqLen * idxReq : params.qCuSeqLens[idxReq]; +} + +struct SpecDec { + static inline constexpr uint32_t tileSize = gemm0CtaTileNbTokens; + static inline constexpr uint32_t ctaMaxQSeqLen = (ctaNbQHeads / headGrpSize); + using TileMaskRow = Vec; + + __device__ inline SpecDec(SpecDecParams const& params, uint32_t idxReq, uint32_t idxInputSubSeq, + uint32_t seqLen) + : params(params), idxInputSubSeq(idxInputSubSeq), seqLen(seqLen) { + inputSeqLen = getInputSeqLen(params, idxReq); + baseOffset = divUp(params.qSeqLen, 32U) * + (getInputTokOffset(params, idxReq) + ctaMaxQSeqLen * idxInputSubSeq); + } + + __device__ inline uint32_t unmaskedSeqLen() const { return seqLen - inputSeqLen; } + + __device__ inline bool needMask(uint32_t idxTile, uint32_t idxQTokInCta) const { + return tileSize * (idxTile + 1) > unmaskedSeqLen() && + ctaMaxQSeqLen * idxInputSubSeq + idxQTokInCta < inputSeqLen && params.mask != nullptr; + } + + __device__ inline int32_t maskColBeg(uint32_t idxTile) const { + int32_t const convergedSeqLen = int32_t(unmaskedSeqLen()); + return static_cast(exactDiv(tileSize, 32) * idxTile) - + static_cast(divUp(convergedSeqLen, 32)); + } + + __device__ inline TileMaskRow loadTileMaskRow(uint32_t idxTile, uint32_t idxQTokInCta) const { + assert(needMask(idxTile, idxQTokInCta)); + constexpr uint32_t nbOrigElems = TileMaskRow::size + 1; + Vec orig; + + int32_t const cols = divUp(params.qSeqLen, 32); + uint32_t const rowOffset = baseOffset + idxQTokInCta * cols; + int32_t const colBeg = maskColBeg(idxTile); +#pragma unroll + for (int32_t i = 0; i < int32_t(nbOrigElems); i++) { + int32_t const idx = colBeg + i; + orig[i] = inRange(idx, 0, cols) ? params.mask[rowOffset + idx] : (idx < 0 ? ~0U : 0U); + } + TileMaskRow mask; + uint32_t const shift = (32 - unmaskedSeqLen() % 32) % 32; +#pragma unroll + for (uint32_t i = 0; i < TileMaskRow::size; i++) { + asm("shf.r.clamp.b32 %0, %1, %2, %3;\n" + : "=r"(mask[i]) + : "r"(orig[i]), "r"(orig[i + 1]), "r"(shift)); + } + return mask; + } + + SpecDecParams const& params; + uint32_t const idxInputSubSeq; + uint32_t const seqLen; + uint32_t inputSeqLen; + uint32_t baseOffset; +}; + +__device__ void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec, +#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE + int32_t tok0WinBeg, +#endif + uint32_t cacheSeqLen, uint32_t idxTile, uint32_t warpRank); +#endif + +#if SWAP_AB +__device__ RegColWiseVec computeWarpGrpColMax_sync(CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, + Gemm0Acc const& src); +__device__ void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, + uint32_t validRowEnd); +__device__ void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegColWiseVec const& colMax); +__device__ RegColWiseVec computeWarpColSum(Gemm0Acc& src); +__device__ void storeGemm0AccToShm(uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, + CtaBarrier& barConsumed, Gemm0Acc const& acc); +__device__ RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smemVec); +__device__ RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec, uint32_t bound); +#else +__device__ RegRowWiseVec computeWarpGrpRowMax_sync(uint32_t warpRank, ShmQWiseVec& smemColMax, + Gemm0Acc const& src); +__device__ void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uint32_t validColEnd); +__device__ void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegRowWiseVec const& colMax); +__device__ RegRowWiseVec computeWarpRowSum(Gemm0Acc& src); +__device__ void storeGemm0AccToShm(uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, + 
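// --- What the shf.r.clamp loop above computes (equivalent intrinsic form) ---
// loadTileMaskRow realigns packed 32-bit mask words to the tile boundary; each
// output word funnels in bits from its right-hand neighbour. The same thing
// with the standard CUDA intrinsic instead of inline PTX:
//
//   for (uint32_t i = 0; i < TileMaskRow::size; i++) {
//     mask[i] = __funnelshift_rc(orig[i], orig[i + 1], shift);  // {hi:lo} >> shift
//   }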
CtaBarrier& barConsumed, Gemm0Acc const& acc); +__device__ RegRowWiseVec loadShmRowWiseVecWithDup(uint32_t warpRank, ShmQWiseVec const& smemVec); +__device__ void storeShmRowWiseVec(uint32_t warpRank, ShmQWiseVec& smemVec, + RegRowWiseVec const& regVec); +#endif + +using RegMatAFrag = Array2D, 1, 2>; +constexpr uint32_t gemm1NbGmmaInstK = exactDiv(gemm1CtaTileNbTokens, gmma::instK); + +#if SWAP_AB +constexpr uint32_t gemm1NbGmmaInstM = exactDiv(headElems, gmma::instM); +__device__ Vec loadVTileTransposed(uint32_t warpRank, uint32_t lane, + SharedMem::VBuffer const& smemV, + uint32_t idxGmmaInstK); +using Gemm1Acc = GmmaAcc; +__device__ void rescaleGemm1AccForNewColMax_sync(uint32_t warpRank, ShmQWiseVec const& shmXColMax, + ShmQWiseVec const (&shmXColSum)[gemm0NbWarps], + ShmQWiseVec& shmAccColMax, Gemm1Acc& acc, + ShmQWiseVec& shmAccColSum, + CtaBarrier& gemm1WarpGrpBar); +template +__device__ void finalizeAndWriteOut_sync( + uint32_t threadRank, uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, + Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar, ShmQWiseVec const& accColSum, + ShmQWiseVec const& accColMax, ShmQWiseVec const* attentionSinksVec, + uint32_t nbKHeads = 0 /* only for final result in spec dec. */); +#else +__device__ void transposeVTile(uint32_t warpRank, uint32_t lane, SharedMem::VTBuffer& dst, + SharedMem::VBuffer const& src); +using Gemm1Acc = GmmaAcc; +__device__ void rescaleGemm1AccForNewRowMax_sync(uint32_t warpRank, ShmQWiseVec const& shmXRowMax, + ShmQWiseVec const(&shmXRowSum), + ShmQWiseVec& shmAccRowMax, Gemm1Acc& acc, + ShmQWiseVec& shmAccRowSum); +template +__device__ void finalizeAndWriteOut_sync( + uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, + float xvoScale, ShmQWiseVec const& accColSum, + uint32_t nbKHeads /* only for final result in spec dec. set to 1 for workspace*/, + uint32_t ctaNbValidTokens); +#endif + +inline constexpr uint32_t ropeNbPairsPerThrdImpl(uint32_t nbThrds) { + auto const val = divUp(exactDiv(validElemsPerHead, 2), nbThrds); + assert(val <= 32); + return val <= 2 ? val : (val <= 4 ? 4 : (val <= 8 ? 8 : (val <= 16 ? 
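// --- ropeNbPairsPerThrdImpl behaviour (worked example) ---
// The helper spreads validElemsPerHead / 2 rotary pairs across nbThrds threads
// and rounds the per-thread count up to a supported width (1, 2, 4, 8, 16, 32).
// For validElemsPerHead == 128 (64 pairs):
//
//   ropeNbPairsPerThrdImpl(32) == 2;  // divUp(64, 32) == 2
//   ropeNbPairsPerThrdImpl(24) == 4;  // divUp(64, 24) == 3, rounded up to 4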
16 : 32))); +} + +template +inline constexpr uint32_t ropeNbPairsPerThrd = ropeNbPairsPerThrdImpl(nbThrds); + +template +__device__ Vec, ropeNbPairsPerThrd> loadHead( + Vec const& head, uint32_t tid); +template +__device__ mha::conditional_t, 2>, + Vec, nbPairsPerThrd>> +applyRoPE(Vec, nbPairsPerThrd> const& data, + Vec, nbPairsPerThrd> const& ropeCosSin); +template +__device__ void storeRotatedPairsForKV( + GMemCacheHead& dst, + mha::conditional_t>, 2>, + Vec, ropeNbPairsPerThrd>> const& src, + uint32_t tid); +template +__device__ void storeRotatedPairsForQ( + SharedMem::QBuffer& dst, + mha::conditional_t>, 2>, + Vec, ropeNbPairsPerThrd>> const& src, + uint32_t row, uint32_t tid); + +class ScratchMem { + public: + struct alignas(8) SumMax { + float sum; + float max; + }; + + using ColWiseVec = Vec; + + HOST_DEVICE_FUNC ScratchMem(void* scratch, uint32_t maxTotalNbSubSeq, uint32_t nbInputSeqSplit) + : mScratch{static_cast(scratch)} { + uint32_t const nbChunks = maxTotalNbSubSeq * nbInputSeqSplit; + Segmenter segmenter; + constexpr uint32_t alignment = sizeof(Vec); + mRowSumMax = segmenter.template newSeg(nbChunks, alignment); + mTokens = segmenter.template newSeg>(nbChunks, alignment); + } + + HOST_DEVICE_FUNC TinyPtr rowSumMax() const { return makePtr(mRowSumMax); } + + HOST_DEVICE_FUNC TinyPtr> tokens() const { + return makePtr>(mTokens); + } + + private: + template + HOST_DEVICE_FUNC TinyPtr makePtr(uint32_t offset) const { + return TinyPtr{mScratch, offset}.template cast(); + } + + private: + mha::byte* mScratch; + // offsets + uint32_t mRowSumMax; + uint32_t mTokens; +}; + +struct MultiBlockSMem { + using ColWiseVec = ScratchMem::ColWiseVec; + static constexpr uint32_t nbBuf = useSpecDec ? 2 : 4; + static constexpr uint32_t nbIOWarps = nbBuf; + using Elem = InputElem; + using Head = Vec; + Vec, nbBuf> tokens; + Vec rowSumMax; + Vec barriers; +}; + +#ifndef NDEBUG +namespace dbg { +template +__device__ void printAcc(CtaBarrier& warpGrpBar, uint32_t warpRank, + Array2D const& acc) { + for (int m = 0; m < nbGmmaInstM; m++) { + for (int w = 0; w < 4; w++) { + if (warpRank == w) { + for (int a = 0; a < 2; a++) { + for (int b = 0; b < 8; b++) { + for (int n = 0; n < nbGmmaInstNBase; n++) { + for (uint32_t i = 0; i < 4; i++) { + if (laneId() == b * 4 + i) { + printf("%f, %f, ", acc(m, n)(a, 0), acc(m, n)(a, 1)); + } + __syncwarp(); + } + } + if (laneId() == 0) { + printf("\n"); + } + __syncwarp(); + } + if (laneId() == 0) { + printf("\n"); + } + __syncwarp(); + } + } + warpGrpBar.arrive_and_wait(); + } + } +} + +__device__ void printShmColWiseVec(ShmQWiseVec const& vec) { + for (uint32_t i = 0; i < vec.size; i++) { + printf("%f, ", vec[i]); + } + printf("\n"); +} + +template +__device__ void printArray2D(Array2D const& src) { + for (uint32_t i = 0; i < rows; i++) { + for (uint32_t j = 0; j < cols; j++) { + T const val = src.template at(i, j); + for (uint32_t k = 0; k < exactDiv(sizeof(T), sizeof(Elem)); k++) { + printf("%f, ", float(reinterpret_cast(&val)[k])); + } + } + printf("\n"); + } +} +} // namespace dbg +#endif + +CUBIN_EXPORT __device__ constexpr XQAKernelType kernelType = + XQAKernelType::kHOPPER_WARP_SPECIALIZED; + +CUBIN_EXPORT __global__ +#ifdef NDEBUG +#if !OPTIMIZE_FOR_LATENCY +__launch_bounds__(128 * 3, headElems* ctaNbQHeads <= 128 * 16 ? 
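// --- Occupancy hint (illustrative) ---
// __launch_bounds__(nbThreads, minBlocksPerSM) caps register usage so the
// stated number of 384-thread CTAs can co-reside on one SM: 3 for small
// ctaNbQHeads, otherwise 2. A minimal standalone equivalent:
//
//   __global__ void __launch_bounds__(384, 2) exampleKernel() {}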
3 : 2) +#else +__launch_bounds__(128 * 3) +#endif +#else + __launch_bounds__(128 * 3, 1) +#endif + void kernel_mha( + uint32_t const nbKHeads, +#if SLIDING_WINDOW + uint32_t const slidingWinSize, +#endif + float const qScale, + OutputHead* __restrict__ const output, // [nbReq][beamWidth][nbQHeads] +#if LOW_PREC_OUTPUT + float const* const rcpOutScale, +#endif +#if USE_INPUT_KV + IOHead const* __restrict__ const qkv, // [nbReq][beamWidth][nbQHeads+nbKHeads+nbVHeads], +#if ROPE_STYLE != 0 + Vec const* __restrict__ const ropeCosSin, // [maxNbPosEmb] +#endif +#else + IOHead const* __restrict__ const q, // [nbReq][beamWidth][nbQHeads], +#endif + float const* attentionSinks, // [headGrpSize] + KVCacheList const cacheList, +#if USE_BEAM_SEARCH + BeamSearchParams const beamSearchParams, +#endif + uint32_t const batchSize, + float const* __restrict__ const kvCacheScale, // Device memory scalar. Same scale for K and + // V cache. Used only for int8/fp8 KV cache. +#if PAGED_KV_CACHE_LAYOUT == 1 + __grid_constant__ CUtensorMap const tensorMapVLLMK, + __grid_constant__ CUtensorMap const tensorMapVLLMV, +#else + __grid_constant__ CUtensorMap const tensorMap, +#endif +#if SPEC_DEC + SpecDecParams const specDecParams, +#endif + uint32_t* __restrict__ const semaphores = + nullptr, // [nbReq][nbKHeads][divUp(specDecParams.qSeqLen, inputTokensPerCta)] + void* __restrict__ const scratch = nullptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) && \ + (IS_SUPPORTED_F16_CASE || CACHE_ELEM_ENUM == 2) && BEAM_WIDTH == 1 + uint32_t const idxReq = blockIdx.z / nbKHeads; +#if SPEC_DEC + uint32_t const reqInputTokBeg = getInputTokOffset(specDecParams, idxReq); + uint32_t const reqInputTokEnd = getInputTokOffset(specDecParams, idxReq + 1); + uint32_t const nbInputSeqSplit = gridDim.x; + assert(nbInputSeqSplit == divUp(specDecParams.qSeqLen, inputTokensPerCta)); +#else + uint32_t const reqInputTokBeg = idxReq; + uint32_t const reqInputTokEnd = idxReq + 1; + constexpr uint32_t nbInputSeqSplit = 1; + assert(gridDim.x == nbInputSeqSplit); +#endif + uint32_t const idxHeadGrp = blockIdx.z % nbKHeads; // inside one request + assert(gridDim.z == nbKHeads * batchSize); + uint32_t const cacheSeqLen = getCacheSeqLen(cacheList, idxReq); + static_assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens); + constexpr uint32_t tileSize = gemm0CtaTileNbTokens; +#if SPEC_DEC + uint32_t const idxInputSubSeq = blockIdx.x; + uint32_t const inputSeqLen = reqInputTokEnd - reqInputTokBeg; + uint32_t const ctaTokOffset = inputTokensPerCta * idxInputSubSeq; + uint32_t const ctaNbValidTokens = + mha::min(uint32_t{inputTokensPerCta}, inputSeqLen - ctaTokOffset); + + if (ctaTokOffset >= inputSeqLen) { + return; + } +#else + uint32_t const idxInputSubSeq = 0; + uint32_t const inputSeqLen = 1; + uint32_t const ctaTokOffset = 0; + uint32_t const ctaNbValidTokens = 1; +#endif +#if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE + // get the actual start position depending on ctaTokOffset, which is the draft token position per + // CTA + uint32_t const tok0SeqLen = cacheSeqLen - inputSeqLen + 1 + ctaTokOffset; + int32_t const tok0WinBeg = int32_t(tok0SeqLen) - int32_t(slidingWinSize); + uint32_t const nbTotalSkipTokens = mha::max(0, tok0WinBeg); +#elif SLIDING_WINDOW + bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize); + // if SPEC_DEC && SLIDING_WINDOW && IS_SPEC_DEC_TREE, it should not do sliding + assert(!SPEC_DEC || !rtIsReallySliding); + uint32_t const nbTotalSkipTokens = rtIsReallySliding ? 
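// --- Sliding-window tile skipping (worked numbers) ---
// With cacheSeqLen == 1000, slidingWinSize == 256, tileSize == 64:
//
//   nbTotalSkipTokens  == 1000 - 256 == 744;
//   nbSkipLeadingTiles == 744 / 64   == 11;  // whole tiles left of the window
//   tile0NbSkipTokens  == 744 % 64   == 40;  // tokens masked in the first kept tile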
cacheSeqLen - slidingWinSize : 0; +#else + constexpr bool rtIsReallySliding = false; + constexpr uint32_t nbTotalSkipTokens = 0; +#endif + uint32_t const nbSkipLeadingTiles = nbTotalSkipTokens / tileSize; + uint32_t const tile0NbSkipTokens = nbTotalSkipTokens % tileSize; + +#if USE_BEAM_SEARCH + uint32_t const ctxCacheSeqLen = getCtxCacheSeqLen(beamSearchParams, idxReq); + uint32_t const nbCtxKTiles = useKVCache ? ctxCacheSeqLen / gemm0CtaTileNbTokens : 0; + uint32_t const nbDivergentKTiles = + useKVCache + ? divUp(cacheSeqLen - gemm0CtaTileNbTokens * nbCtxKTiles, beamSearchGemm0CtaTileNbTokens) + : 0; + uint32_t const nbKTiles = nbCtxKTiles + nbDivergentKTiles; + uint32_t const nbVTiles = nbKTiles; +#else + uint32_t const nbTiles = useKVCache ? divUp(cacheSeqLen, tileSize) : 0; + // uint32_t const nbKTiles = nbTiles; + // uint32_t const nbVTiles = nbTiles; + uint32_t const nbTilesInUse = nbTiles - nbSkipLeadingTiles; +#endif + uint32_t const maxNbSubSeq = gridDim.y; + uint32_t const idxSubSeq = blockIdx.y; + bool const isMultiBlockMode = (maxNbSubSeq > 1 && nbTilesInUse >= multiBlockMinNbTiles); + uint32_t const idxKTileInit = nbSkipLeadingTiles + idxSubSeq; + uint32_t const idxVTileInit = idxKTileInit; + uint32_t const nbSubSeq = + isMultiBlockMode ? mha::min(nbTilesInUse / multiBlockMinNbTilesPerCta, maxNbSubSeq) : 1; + static_assert(multiBlockMinNbTiles >= multiBlockMinNbTilesPerCta * 2); + assert(isMultiBlockMode == (nbSubSeq > 1)); + if (idxSubSeq >= nbSubSeq) { + return; + } + uint32_t const ctaInputTokBeg = reqInputTokBeg + ctaTokOffset; + auto const warpIdx = getWarpIdx(uint3{128, 1, 3}); + auto const wid = warpIdx.z * 4 + warpIdx.x; +#if PAGED_KV_CACHE_LAYOUT == 1 + if (wid == 0 && warpElectSync()) { + tma::prefetchTensorMap(tensorMapVLLMK); + tma::prefetchTensorMap(tensorMapVLLMV); + } +#else + if (wid == 0 && warpElectSync()) { + tma::prefetchTensorMap(tensorMap); + } +#endif + extern __shared__ char smemByteBuf[]; + assert(dynamicSmemSize() >= sizeof(SharedMem)); + SharedMem& smem = *reinterpret_cast(&smemByteBuf[0]); + + constexpr uint32_t nbBuffers = 2; + static_assert(nbBuffers == SharedMem::nbKBuf && nbBuffers == SharedMem::nbVBuf && + nbBuffers == SharedMem::nbXBuf); + if (wid < nbBuffers) { + if (warpElectSync()) { + smem.kBar[wid].initialize(gemm0NbThrds, gemm0NbThrds + warp_size); + smem.vBar[wid].initialize(gemm1NbThrds, gemm1NbThrds + warp_size); +#if !SWAP_AB + smem.vtBar[wid].initialize(gemm1NbThrds * 2, gemm1NbThrds * 2); +#endif + smem.xBar[wid].initialize(gemm0NbThrds + gemm1NbThrds, gemm0NbThrds + gemm1NbThrds); + } + } else if (wid == nbBuffers) { + if (warpElectSync()) { + smem.qBar.initialize(gemm0NbThrds + nbQLdThrds, gemm0NbThrds + nbQLdThrds); + init(&smem.gemm0WarpGrpBar, gemm0NbThrds); + init(&smem.gemm1WarpGrpBar, gemm1NbThrds); + } + } + __syncthreads(); + +#if USE_PAGED_KV_CACHE + uint32_t const nbPages = divUp(cacheSeqLen, tokensPerPage); +#endif + + constexpr bool isKVCacheQuantized = (cacheElemSize < 2); + assert(idxKTileInit < nbTiles); + uint32_t const nbIters = divUp(nbTiles - idxKTileInit, nbSubSeq); + assert(nbIters >= 1); + + constexpr uint32_t gmmaInstK = gmma::instK; + constexpr uint32_t grainsPerInstK = exactDiv(sizeof(MathElem) * gmmaInstK, grainBytes); + + if (warpIdx.z == 0) { +#if SPEC_DEC + SpecDec const specDec{specDecParams, idxReq, idxInputSubSeq, cacheSeqLen}; +#endif + + // QK gemm + constexpr uint32_t nbGmmaInstM = exactDiv(gemm0CtaTileNbTokens, gmma::instM); + using Acc = GmmaAcc; + + unused(smem.qBar.consumed.arrive()); + for 
(auto& b : smem.kBar) { + unused(b.consumed.arrive()); + } + + float const qkScale = + qScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f) * + rsqrtf(validElemsPerHead); // qkScale is applied onto Q*K.T before softmax. + uint32_t const warpRank = warpIdx.x; + + // init once per sequence. It also works as global colMax across iterations. + if (threadIdx.x < ctaNbQHeads) { + smem.gemm0CurrentSeqMax[threadIdx.x] = safeInitRowMax; + } + smem.gemm0WarpGrpBar.arrive_and_wait(); + + smem.qBar.produced.arrive_and_wait(); +#if DBG_PRINT + if (threadIdx.x == 0) { + printf("q:\n"); + dbg::printArray2D<__nv_fp8_e4m3, true>(smem.q[0]); + } +#endif + + auto const matDescQBase = + gmma::makeMatDesc(nullptr, 0, SharedMem::QBuffer::Elem::rowBytes * 8, + gmma::getSwizzleMode(SharedMem::QBuffer::Elem{})) + .raw(); + for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) { + uint32_t const idxKTile = idxKTileInit + idxIter * nbSubSeq; + assert(idxKTile < nbTiles); + Acc acc; // no need to initialize. GMMA allows us to ignore acc initial values. + gmma::fence(); + static_assert(cacheHeadNbParts == nbQParts); +#pragma unroll + for (uint32_t idxPart = 0; idxPart < cacheHeadNbParts; idxPart++) { + auto const idxKBuf = (idxIter * cacheHeadNbParts + idxPart) % SharedMem::nbKBuf; + auto& kBuf = smem.k[idxKBuf]; + auto& kBar = smem.kBar[idxKBuf]; + static_assert(SharedMem::KBuffer::rows % 8 == 0); + auto const matDescKBase = + gmma::makeMatDesc(nullptr, 0, SharedMem::KBuffer::rowBytes * 8, &smem.k[0], + gmma::getSwizzleMode(SharedMem::KBuffer{})) + .raw(); + assert(matDescKBase == gmma::makeMatDesc(nullptr, 0, SharedMem::KBuffer::rowBytes * 8, + gmma::getSwizzleMode(SharedMem::KBuffer{})) + .raw()); + arrive_tx_and_wait(kBar.produced, exactDiv(sizeof(SharedMem::KBuffer), gemm0NbThrds)); + // if (threadIdx.x == 0) { + // printf("************* part %u *******\n", idxPart); + // printf("q:\n"); + // dbg::printArray2D<__nv_fp8_e4m3, true>(smem.q[idxPart]); + // printf("k:\n"); + // dbg::printArray2D<__nv_fp8_e4m3, true>(kBuf); + // } + constexpr uint32_t nbGmmaInstK = exactDiv(cacheHeadPartElems, gmmaInstK); +#pragma unroll + for (uint32_t k = 0; k < nbGmmaInstK; k++) { + bool const accHasVal = (idxPart != 0 || k != 0); + auto const matDescQ = addAddr(matDescQBase, &smem.q[idxPart](0, grainsPerInstK * k)); +#pragma unroll + for (uint32_t m = 0; m < nbGmmaInstM; m++) { + auto const matDescK = addAddr(matDescKBase, &kBuf(64 * m, grainsPerInstK * k)); +#if SWAP_AB + gmma::mma_async_shmA( + reinterpret_cast(acc(m, 0)), + matDescK, matDescQ, accHasVal); +#else + gmma::mma_async_shmA( + reinterpret_cast(acc(m, 0)), + matDescQ, matDescK, accHasVal); +#endif + } + } + gmma::commit_group(); + //@fixme: use two sets of acc and let gmma_async overlap with softmax. But this will let + // tile0_softmax wait for k loading of tile1 and may harm perf for short-seq cases. 
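+      // Illustrative sketch of the double-buffered accumulator idea from the fixme
+      // above (disabled; `accBuf` is a hypothetical name, not part of this change).
+      // Keeping one wgmma group in flight would let softmax on tile i-1 overlap the
+      // QK GEMM of tile i:
+#if 0
+      Acc accBuf[2];          // issue tile i into accBuf[i % 2]
+      gmma::commit_group();
+      gmma::wait_group<1>();  // tile i may stay in flight; run softmax on accBuf[(i + 1) % 2]
+#endif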
+ gmma::wait_group<0>(); + unused(kBar.consumed.arrive()); + } +#if !defined(NDEBUG) && DBG_PRINT + dbg::printAcc(smem.gemm0WarpGrpBar, warpRank, acc); +#endif + // apply qkScale + acc = acc * qkScale; + // apply mask +#if SPEC_DEC + warpGrpApplyMask(acc, specDec, +#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE + tok0WinBeg, +#endif + cacheSeqLen, idxKTile, warpRank); +#else + bool const isFirstTile = (idxKTile == nbSkipLeadingTiles); + bool const needMaskLeading = (rtIsReallySliding && isFirstTile && tile0NbSkipTokens > 0); + bool const isLastTile = (idxKTile + 1 == nbTiles); + bool const needMaskTrailing = isLastTile && cacheSeqLen % tileSize != 0; + if (needMaskLeading || needMaskTrailing) { + uint32_t const validTokenBeg = needMaskLeading ? tile0NbSkipTokens : 0; + uint32_t const validTokenEnd = (needMaskTrailing ? cacheSeqLen % tileSize : tileSize); + if (validTokenBeg > 0 || validTokenEnd < tileSize) { +#if SWAP_AB + warpGrpApplyMask(warpRank, acc, validTokenBeg, validTokenEnd); +#else + warpGrpApplyMask(acc, validTokenBeg, validTokenEnd); +#endif + } + } +#endif + // update colMax in shared mem and get a register copy +#if SWAP_AB + RegColWiseVec const colMax = + computeWarpGrpColMax_sync(smem.gemm0WarpGrpBar, smem.gemm0CurrentSeqMax, acc); + warpGrpOnlineSoftmax(acc, colMax); +#else + RegRowWiseVec const rowMax = + computeWarpGrpRowMax_sync(warpRank, smem.gemm0CurrentSeqMax, acc); + warpGrpOnlineSoftmax(acc, rowMax); +#endif + + // @fixme: may need fp32->fp8->fp32 before doing sum. +#if SWAP_AB + RegColWiseVec const warpColSum = computeWarpColSum(acc); +#else + RegRowWiseVec const rowSum = computeWarpRowSum(acc); +#endif + + // map 1 to fp8_max before conversion to fp8 + acc = acc * kE4M3_MAX; + + uint32_t const idxXBuf = idxIter % SharedMem::nbXBuf; + auto& xBar = smem.xBar[idxXBuf]; + // @fixme: for fp16/bf16, try not to transpose acc here, and leave it to the next GEMM. +#if SWAP_AB + storeGemm0AccToShm(warpRank, laneId(), smem.xBuf(idxXBuf), xBar.consumed, acc); + // store colMax and warpColSum + auto const lane = laneId(); + if (lane < 4) { + auto& xColMax = smem.xColMax[idxXBuf]; + auto& xColSum = smem.xColSum[idxXBuf][warpRank]; +#pragma unroll + for (uint32_t n = 0; n < colMax.size; n++) { +#pragma unroll + for (uint32_t j = 0; j < 2; j++) { + if (warpRank == 0) { + xColMax[8 * n + 2 * lane + j] = colMax[n][j]; + } + xColSum[8 * n + 2 * lane + j] = warpColSum[n][j]; + } + } + } +#else + storeGemm0AccToShm(warpRank, laneId(), smem.xBuf(idxXBuf), xBar.consumed, acc); + storeShmRowWiseVec(warpRank, smem.xRowMax[idxXBuf], rowMax); + storeShmRowWiseVec(warpRank, smem.xRowSum[idxXBuf], rowSum); +#endif + + __syncwarp(); + // the release semantics of arrive does not work for async consumers like gmma. additional + // fence is needed. 
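+      // Spelled out, the ordering that the fence below enforces (illustrative):
+      //   1. the stmatrix/st.shared writes of X (and rowMax/rowSum) above land in
+      //      shared memory through the generic proxy;
+      //   2. fence.proxy.async.shared::cta publishes those writes to the async proxy;
+      //   3. xBar.produced.arrive() signals the gemm1 warp group, whose wgmma then
+      //      reads X through the async proxy.
+      // Without step 2, the mbarrier's release semantics alone would not order step 1
+      // against those async-proxy reads.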
+ asm volatile("fence.proxy.async.shared::cta;\n"); + unused(xBar.produced.arrive()); + } + unused(smem.qBar.consumed.arrive()); + } else if (warpIdx.z == 1) { + // XV GEMM + for (auto& b : smem.vBar) { + unused(b.consumed.arrive()); + } +#if !SWAP_AB + for (auto& b : smem.vtBar) { + unused(b.consumed.arrive()); + } +#endif + for (auto& b : smem.xBar) { + unused(b.consumed.arrive()); + } + + if (threadIdx.x < smem.gemm1AccColMax.size) { + auto const idx = threadIdx.x; + smem.gemm1AccColMax[idx] = safeInitRowMax; + smem.gemm1AccColSum[idx] = 0; + } + smem.gemm1WarpGrpBar.arrive_and_wait(); + + uint32_t const warpRank = warpIdx.x; + + constexpr float xScale = 1.f / kE4M3_MAX; +#if LOW_PREC_OUTPUT + float const oScale = rcpOutScale[0]; +#else + constexpr float oScale = 1.F; +#endif + float const xvoScale = xScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f) * oScale; + + Gemm1Acc acc{}; // init to zeros to avoid runtime checking for first gmma instruction. + gmma::fence(); + + static_assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens, "not implemented"); + for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) { + uint32_t idxVTile = idxVTileInit + idxIter * nbSubSeq; + auto const idxVBuf = idxIter % SharedMem::nbVBuf; + auto const idxXBuf = idxVBuf; + auto& vBar = smem.vBar[idxVBuf]; + arrive_tx_and_wait(vBar.produced, exactDiv(sizeof(SharedMem::VBuffer), gemm1NbThrds)); + auto const& vBuf = smem.vBuf(idxVBuf); +#if !SWAP_AB + CtaBarrierPair& vtBar = smem.vtBar[idxVBuf]; + auto& vtBuf = smem.vtBuf(idxVBuf); + vtBar.consumed.arrive_and_wait(); + transposeVTile(warpRank, laneId(), vtBuf, vBuf); + vBar.consumed.arrive(); + vtBar.produced.arrive(); +#endif + auto& xBar = smem.xBar[idxXBuf]; + xBar.produced.arrive_and_wait(); +#if !defined(NDEBUG) && DBG_PRINT +#if SWAP_AB + if (threadIdx.x == 0) { + printf("colMax:\n"); + for (int i = 0; i < ctaNbQHeads; i++) { + printf("%f, ", smem.xColMax[idxXBuf][i]); + } + printf("\n"); + printf("colSum:\n"); + for (int n = 0; n < 4; n++) { + for (int i = 0; i < ctaNbQHeads; i++) { + printf("%f, ", smem.xColSum[idxXBuf][n][i]); + } + printf("\n"); + } + printf("\n"); + printf("X:\n"); + for (int i = 0; i < ctaNbQHeads; i++) { + for (int j = 0; j < gemm0CtaTileNbTokens; j++) { + auto const& elemsPerXPart = (cacheElemsPerGrain * grainsPerXPart); + auto const e = reinterpret_cast&>( + smem.xBuf(idxXBuf)[j / elemsPerXPart].template at( + i, j % elemsPerXPart / cacheElemsPerGrain))[j % cacheElemsPerGrain]; + printf("%.2f, ", float(e)); + if (j % 16 == 15) { + printf("| "); + } + } + printf("\n\n"); + } + } + smem.gemm1WarpGrpBar.arrive_and_wait(); +#else + if (blockIdx.y == 1 && threadIdx.x == 0) { + printf("rowMax:\n"); + for (int i = 0; i < ctaNbQHeads; i++) { + printf("%f, ", smem.xRowMax[idxXBuf][i]); + } + printf("\n"); + printf("rowSum:\n"); + for (int i = 0; i < ctaNbQHeads; i++) { + printf("%f, ", smem.xRowSum[idxXBuf][i]); + } + printf("\n"); + } + smem.gemm1WarpGrpBar.arrive_and_wait(); +#endif +#endif + +#if SWAP_AB + // @fixme: if first tile, no need to rescale acc. For persistent CTA, just re-initialize acc + // instead. 
+ rescaleGemm1AccForNewColMax_sync(warpRank, smem.xColMax[idxXBuf], smem.xColSum[idxXBuf], + smem.gemm1AccColMax, acc, smem.gemm1AccColSum, + smem.gemm1WarpGrpBar); +#else + rescaleGemm1AccForNewRowMax_sync(warpRank, smem.xRowMax[idxXBuf], smem.xRowSum[idxXBuf], + smem.gemm1AccColMax, acc, smem.gemm1AccColSum); +#endif + auto& xBuf = smem.xBuf(idxXBuf); + + auto const descXBase = + gmma::makeMatDesc(nullptr, 0, SharedMem::XBuffer::Elem::rowBytes * 8, + gmma::getSwizzleMode(SharedMem::XBuffer::Elem{})) + .raw(); +#if CACHE_ELEM_ENUM == 0 + auto const descVBase = + gmma::makeMatDesc(nullptr, 0, SharedMem::VBuffer::Elem::rowBytes * 8, + gmma::getSwizzleMode(SharedMem::VBuffer::Elem{})) + .raw(); +#endif +#if SWAP_AB +//@fixme: to reduce code size, we can disable unroll and use double-buffer for LDSM in +// loadVTileTransposed. +#pragma unroll + for (uint32_t idxInstK = 0; idxInstK < gemm1NbGmmaInstK; idxInstK++) { +#if CACHE_ELEM_ENUM == 2 + Vec const fragA = + loadVTileTransposed(warpRank, laneId(), vBuf, idxInstK); +#if !defined(NDEBUG) && DBG_PRINT + if (threadIdx.x == 0) { + printf("fragA:\nidxInstK == %u\n", idxInstK); + } + smem.gemm1WarpGrpBar.arrive_and_wait(); + for (int m = 0; m < 2; m++) { + for (int w = 0; w < 4; w++) { + if (warpRank == w) { + if (laneId() == 0) { + printf(" warpRank = %u\n", warpRank); + } + __syncwarp(); + for (int a = 0; a < 2; a++) { + for (int b = 0; b < 8; b++) { + for (int c = 0; c < 2; c++) { + for (int d = 0; d < 4; d++) { + if (laneId() == b * 4 + d) { + for (int e = 0; e < 4; e++) { + auto const& elem4 = + reinterpret_cast<__nv_fp8_e4m3 const(&)[4]>(fragA[m](0, c)(a, 0)); + printf("%.2f, ", float(elem4[e])); + } + } + __syncwarp(); + } + } + if (laneId() == 0) { + printf("\n"); + } + __syncwarp(); + } + if (laneId() == 0 && a == 0) { + printf("----------------------\n"); + } + __syncwarp(); + } + } + smem.gemm1WarpGrpBar.arrive_and_wait(); + } + } +#endif +#endif + BoundedVal const kOffsetInGrains{grainsPerInstK * + idxInstK}; + auto const descX = + addAddr(descXBase, + &xBuf[kOffsetInGrains.template divBy().get()]( + 0, kOffsetInGrains.template mod().get())); +#if CACHE_ELEM_ENUM == 2 + gmma::fence(); +#endif +#pragma unroll + for (uint32_t idxInstM = 0; idxInstM < gemm1NbGmmaInstM; idxInstM++) { +#if CACHE_ELEM_ENUM == 0 + auto const descV = + addAddr(descVBase, &vBuf[idxInstM](kOffsetInGrains.get() * cacheElemsPerGrain, 0)); + gmma::mma_async_shmA( + reinterpret_cast( + acc(idxInstM, 0)), + descV, descX, true); +#elif CACHE_ELEM_ENUM == 2 + gmma::mma_async_regA( + reinterpret_cast( + acc(idxInstM, 0)), + reinterpret_cast(fragA[idxInstM]), descX, true); +#endif + } + gmma::commit_group(); + //@fixme: delay wait and consumption to next tile. Note that fragA must also persist until + // finish of gmma. 
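+          // Note (illustrative): wait_group<0> drains every in-flight wgmma group, so
+          // the consumed.arrive() calls at the bottom of this loop are safe; once
+          // `consumed` is signaled, the IO warps may immediately overwrite X and V.
+          // The pipelined variant in the fixme (deferring to wait_group<1> on the next
+          // tile) would also have to defer the buffer release and keep fragA alive.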
+ gmma::wait_group<0>(); + } +#else + auto const descVTBase = gmma::makeMatDesc(nullptr, 0, SharedMem::VTBuffer::rowBytes * 8, + gmma::getSwizzleMode(SharedMem::VTBuffer{})) + .raw(); + vtBar.produced.arrive_and_wait(); +// if (idxIter == 1 && threadIdx.x == 0) { +// printf("vtBuf:\n"); +// dbg::printArray2D<__nv_fp8_e4m3, true>(vtBuf); +// } +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { +#pragma unroll + for (uint32_t k = 0; k < gemm1NbGmmaInstK; k++) { + BoundedVal const kOffsetInGrains{grainsPerInstK * k}; + auto const descX = + addAddr(descXBase, + &xBuf[kOffsetInGrains.template divBy().get()]( + gmma::instM * m, + kOffsetInGrains.template mod().get())); + auto const descVT = + addAddr(descVTBase, + &vtBuf(0, kOffsetInGrains.template mod().get())); + gmma::mma_async_shmA( + reinterpret_cast(acc(m, 0)), + descX, descVT, true); + } + } + gmma::commit_group(); + //@fixme: delay wait and consumption to next tile. Note that fragA must also persist until + // finish of gmma. + gmma::wait_group<0>(); +#endif + if (idxIter == nbIters - 1) { + // gmma::wait_group should have already synchronized threads, so this may be unnecessary. + smem.gemm1WarpGrpBar.arrive_and_wait(); + assert(idxXBuf == idxVBuf); + if (isMultiBlockMode) { + ScratchMem const scratchMem{scratch, maxNbSubSeq * nbKHeads * batchSize, nbInputSeqSplit}; + uint32_t const idxSeq = nbKHeads * idxReq + idxHeadGrp; + uint32_t const idxAllSubSeq = maxNbSubSeq * idxSeq + idxSubSeq; + uint32_t const idxChunk = idxAllSubSeq * nbInputSeqSplit + idxInputSubSeq; + // save row max/sum + static_assert(ctaNbValidQHeads <= gmmaWarpsPerGrp * warp_size); + if (threadIdx.x < ctaNbValidQHeads) { + float const colMax = smem.gemm1AccColMax[threadIdx.x]; + float const colSum = smem.gemm1AccColSum[threadIdx.x]; + ScratchMem::SumMax sumMax; + sumMax.sum = colSum; + sumMax.max = colMax; + (scratchMem.rowSumMax() + idxChunk).template cast()[threadIdx.x] = + sumMax; + } + // compute scratch ptr for output writing + IOHead* const dst = (scratchMem.tokens() + idxChunk).template cast(); +#if SWAP_AB + finalizeAndWriteOut_sync(threadIdx.x, warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, + xvoScale, smem.gemm1WarpGrpBar, smem.gemm1AccColSum, + smem.gemm1AccColMax, nullptr); +#else + finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale, + smem.gemm1AccColSum, 1, ctaNbValidTokens); +#endif + } else { + uint32_t const outOffset = + headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp); + OutputHead* const dst = &output[outOffset]; + ShmQWiseVec const* attentionSinksVec = nullptr; + if (attentionSinks != nullptr) { + attentionSinksVec = + reinterpret_cast(attentionSinks + headGrpSize * idxHeadGrp); + } +#if SWAP_AB + finalizeAndWriteOut_sync(threadIdx.x, warpRank, dst, + smem.outSwizzleBuf(idxXBuf), acc, xvoScale, + smem.gemm1WarpGrpBar, smem.gemm1AccColSum, + smem.gemm1AccColMax, attentionSinksVec, nbKHeads); +#else + finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale, + smem.gemm1AccColSum, nbKHeads, ctaNbValidTokens); +#endif + } + } + unused(xBar.consumed.arrive()); +#if SWAP_AB + unused(vBar.consumed.arrive()); +#else + unused(vtBar.consumed.arrive()); +#endif + } + } else { + // IO warps + static_assert(beamWidth == 1); +#if ENABLE_PDL + preExit(); +#endif +#if ENABLE_PDL == 1 + acqBulk(); +#endif + assert(warpIdx.z == 2); + uint32_t const newTokenPos = cacheSeqLen - 1; + if (warpIdx.x < nbQLdWarps) { + // load Q. 
Use register to load fp16 data and store fp8 to shared mem. + // @fixme: If register pressure is high and shared mem pressure is low, switch to TMA instead. + using QCvt = F16QToF8Converter; + static_assert(beamWidth == 1); +#if USE_INPUT_KV + TinyPtr const qData{ + qkv, headGrpSize * idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq}; + constexpr bool isNeox = (ROPE_STYLE == 1); + constexpr uint32_t thrdsPerHead = mha::min(warp_size, divUp(headElems, 4U)); + uint32_t const lane = laneId(); + uint32_t const idxThrd = warpIdx.x * warp_size + lane; + uint32_t const idxThrdGrp = + (thrdsPerHead % 32 == 0 ? makeWarpUniform(this_warp(), idxThrd / thrdsPerHead) + : idxThrd / thrdsPerHead); + constexpr uint32_t nbThrdGrps = exactDiv(warp_size * nbQLdWarps, thrdsPerHead); + uint32_t const tid = idxThrd % thrdsPerHead; + smem.qBar.consumed.arrive_and_wait(); +#if ROPE_STYLE != 0 + auto const& ropeCosSinHead = + reinterpret_cast const&>(ropeCosSin[cacheSeqLen - 1]); + auto const cosSinPairs = loadHead(ropeCosSinHead, tid); +#endif +#if ENABLE_PDL == 2 + acqBulk(); +#endif +#pragma unroll + for (uint32_t iter = 0; iter < divUp(headGrpSize, nbThrdGrps); iter++) { + uint32_t const idxHead = nbThrdGrps * iter + idxThrdGrp; + if (idxHead >= headGrpSize) { + break; + } +#if ROPE_STYLE == 0 + auto const rotatedPairs = + loadHead(qData[idxHead], tid); +#else + auto const pairs = loadHead(qData[idxHead], tid); + auto const rotatedPairs = applyRoPE(pairs, cosSinPairs); +#endif + storeRotatedPairsForQ(smem.q, rotatedPairs, idxHead, tid); + } +#else + TinyPtr const qData{ + q, headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp)}; +#if ENABLE_PDL == 2 + acqBulk(); +#endif + auto const f16QData = QCvt::load(threadIdx.x, qData, nbKHeads, ctaNbValidTokens); + + smem.qBar.consumed.arrive_and_wait(); + QCvt::store(threadIdx.x, smem.q, f16QData); +#endif + // the release semantics of arrive does not work for async consumers like gmma. additional + // fence is needed. 
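+    // Same generic-proxy to async-proxy ordering as for the X buffer in the gemm0
+    // warps: the Q tile was written with ordinary shared-memory stores, and gemm0
+    // consumes it with wgmma, so the writes must be published to the async proxy
+    // before qBar.produced is signaled.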
+ asm volatile("fence.proxy.async.shared::cta;\n"); + unused(smem.qBar.produced.arrive()); + } else if (warpIdx.x == nbQLdWarps) { // load k + KVTilePartLoader kTilePartLoader{true, nbKHeads, cacheList, idxReq, idxHeadGrp, +#if USE_PAGED_KV_CACHE +#if PAGED_KV_CACHE_LAYOUT == 1 + tensorMapVLLMK, +#else + tensorMap, +#endif + nbPages, smem.pages[0] +#else + tensorMap +#endif + }; + for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) { + uint32_t const idxKTile = idxKTileInit + idxIter * nbSubSeq; + kTilePartLoader.loadPages(idxKTile); +#if USE_INPUT_KV || ENABLE_PDL == 2 +#if SPEC_DEC + bool const anyNewTokens = + (gemm0CtaTileNbTokens * (idxKTile + 1) > cacheSeqLen - inputSeqLen); +#else + bool const anyNewTokens = (gemm0CtaTileNbTokens * (idxKTile + 1) >= cacheSeqLen); +#endif + if (anyNewTokens) { +#if ENABLE_PDL == 2 + acqBulk(); +#endif +#if USE_INPUT_KV + static_assert(beamWidth == 1); + uint32_t const inputKHeadOffset = + headGrpSize * nbKHeads + idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq; + IOHead const& inKHead = qkv[inputKHeadOffset]; + uint32_t const lane = laneId(); + float const rcpKScale = 1.F / kvCacheScale[0]; +#if ROPE_STYLE == 0 + constexpr bool isNeox = false; + auto const pairs = + loadHead(inKHead, lane) * rcpKScale; + Vec, decltype(pairs)::size> convertedPairs; + constexpr uint32_t nbElems = decltype(pairs)::Elem::size * decltype(pairs)::size; + reinterpret_cast&>(convertedPairs) = + convert(reinterpret_cast const&>(pairs)); + storeRotatedPairsForKV(kTilePartLoader.getHead(newTokenPos), + convertedPairs, lane); +#else + constexpr bool isNeox = (ROPE_STYLE == 1); + auto const pairs = loadHead(inKHead, lane) * rcpKScale; + auto const& ropeCosSinHead = + reinterpret_cast const&>(ropeCosSin[cacheSeqLen - 1]); + auto const cosSinPairs = loadHead(ropeCosSinHead, lane); + auto const rotatedPairs = applyRoPE(pairs, cosSinPairs); + storeRotatedPairsForKV(kTilePartLoader.getHead(newTokenPos), + rotatedPairs, lane); +#endif + static_assert(inputSeqLen == 1); + __syncwarp(); +#endif + } +#endif + for (uint32_t idxPart = 0; idxPart < cacheHeadNbParts; idxPart++) { + auto const idxKBuf = (idxIter * cacheHeadNbParts + idxPart) % SharedMem::nbKBuf; + auto& kBar = smem.kBar[idxKBuf]; + kBar.consumed.arrive_and_wait(); + if (warpElectSync()) { + kTilePartLoader.loadData(smem.k[idxKBuf], idxKTile, idxPart, kBar.produced); + } + __syncwarp(); + } + } + } else if (warpIdx.x == nbQLdWarps + 1) { // load v + KVTilePartLoader vTileLoader{false, nbKHeads, cacheList, idxReq, idxHeadGrp, +#if USE_PAGED_KV_CACHE +#if PAGED_KV_CACHE_LAYOUT == 1 + tensorMapVLLMV, +#else + tensorMap, +#endif + nbPages, smem.pages[1] +#else + tensorMap +#endif + }; + for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) { + uint32_t const idxVTile = idxVTileInit + idxIter * nbSubSeq; + vTileLoader.loadPages(idxVTile); +#if USE_INPUT_KV || ENABLE_PDL == 2 +#if SPEC_DEC + bool const anyNewTokens = + (gemm0CtaTileNbTokens * (idxVTile + 1) > cacheSeqLen - inputSeqLen); +#else + bool const anyNewTokens = (gemm0CtaTileNbTokens * (idxVTile + 1) >= cacheSeqLen); +#endif + if (anyNewTokens) { +#if ENABLE_PDL == 2 + acqBulk(); +#endif +#if USE_INPUT_KV + static_assert(beamWidth == 1); + uint32_t const inputVHeadOffset = + (headGrpSize + 1) * nbKHeads + idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq; + IOHead const& inVHead = qkv[inputVHeadOffset]; + uint32_t const lane = laneId(); + float const rcpVScale = 1.F / kvCacheScale[0]; + constexpr bool isNeox = false; + auto const pairs = + loadHead(inVHead, lane) 
* rcpVScale; + Vec, decltype(pairs)::size> convertedPairs; + constexpr uint32_t nbElems = decltype(pairs)::Elem::size * decltype(pairs)::size; + reinterpret_cast&>(convertedPairs) = + convert(reinterpret_cast const&>(pairs)); + static_assert(SPEC_DEC == 0); + storeRotatedPairsForKV(vTileLoader.getHead(newTokenPos), + convertedPairs, lane); + __syncwarp(); +#endif + } +#endif + + uint32_t const idxVBuf = idxIter % SharedMem::nbVBuf; + auto& vBar = smem.vBar[idxVBuf]; + vBar.consumed.arrive_and_wait(); + if (warpElectSync()) { +#pragma unroll + for (uint32_t idxPart = 0; idxPart < cacheHeadNbParts; idxPart++) { + vTileLoader.loadData(smem.vBuf(idxVBuf)[idxPart], idxVTile, idxPart, vBar.produced); + } + } + __syncwarp(); + } + } + } + __syncthreads(); + uint32_t const nbBarriers = &smem.gemm1WarpGrpBar - &smem.qBar.produced + 1; + uint32_t const tid = + threadIdx.x + blockDim.x * threadIdx.y + blockDim.x * blockDim.y * threadIdx.z; + assert(nbBarriers <= blockDim.x * blockDim.y * blockDim.z); + if (tid < nbBarriers) { + (&smem.qBar.produced)[tid].~CtaBarrier(); + } + if (!isMultiBlockMode) { + return; + } + bool& smemIsLastCta = smem.isLastCta; + if (threadIdx.x == gemm1NbThrds - 1U && threadIdx.z == 0) { + uint32_t const lastOld = nbSubSeq - 1; + ScratchMem const scratchMem{scratch, maxNbSubSeq * nbKHeads * batchSize, nbInputSeqSplit}; + uint32_t const idxSeq = nbKHeads * idxReq + idxHeadGrp; + uint32_t old; + uint32_t const idxSemaphore = idxSeq * nbInputSeqSplit + idxInputSubSeq; + auto const pSemaphore = &semaphores[idxSemaphore]; + asm volatile("atom.acq_rel.gpu.global.inc.u32 %0, [%1], %2;\n" + : "=r"(old) + : "l"(pSemaphore), "r"(lastOld)); + smemIsLastCta = (old == lastOld); + } + { + assert(dynamicSmemSize() >= sizeof(MultiBlockSMem)); +#ifndef __CUDACC_RTC__ + assert(sizeof(MultiBlockSMem) < offsetof(SharedMem, isLastCta)); +#endif + auto& smem = *reinterpret_cast(&smemByteBuf[0]); + assert(blockDim.x >= MultiBlockSMem::nbBuf); + constexpr uint32_t nbMathWarps = gemm0NbWarps + gemm1NbWarps; + + static_assert(nbWarps >= MultiBlockSMem::nbBuf); + if (wid < MultiBlockSMem::nbBuf) { + if (warpElectSync()) { + smem.barriers[wid].initialize(isHeadPadded ? 
warp_size : 1U, nbMathWarps * warp_size); + smem.barriers[wid].consumed.arrive(nbMathWarps * warp_size); + } + } + __syncthreads(); + + if (!smemIsLastCta) { + return; + } + if (wid < nbMathWarps) { + constexpr uint32_t headsPerWarp = divUp(ctaNbValidQHeads, nbMathWarps); + using Acc = Vec; + + struct HeadState { + Acc acc; + float sum; + float max; + }; + + Vec states{}; + for (auto& s : states.data) { + s.max = safeInitRowMax; + } + uint32_t const lane = laneId(); + for (uint32_t idxBlock = 0; idxBlock < nbSubSeq; idxBlock++) { + uint32_t const idxBuf = idxBlock % MultiBlockSMem::nbBuf; + auto& bar = smem.barriers[idxBuf]; + bar.produced.wait_parity(idxBlock / MultiBlockSMem::nbBuf % 2 != 0); + for (uint32_t i = 0; i < headsPerWarp; i++) { + uint32_t const idxHead = wid + nbMathWarps * i; + if ((ctaNbValidQHeads % nbMathWarps != 0) && (idxHead >= ctaNbValidQHeads)) { + break; + } + HeadState& state = states[i]; + auto const sumMax = smem.rowSumMax[idxBuf][idxHead]; + auto const data = convert(reinterpret_cast&>( + smem.tokens[idxBuf][idxHead][Acc::size * lane])); + if (sumMax.max > state.max) { + float const scale = expf(state.max - sumMax.max); + state.max = sumMax.max; + state.sum = state.sum * scale + sumMax.sum; + state.acc = state.acc * scale + data * sumMax.sum; + } else { + float const scale = expf(sumMax.max - state.max); + state.sum = state.sum + sumMax.sum * scale; + state.acc = state.acc + data * (sumMax.sum * scale); + } + } + unused(bar.consumed.arrive()); + } + // Add the attention sinks. + if (attentionSinks != nullptr) { + for (uint32_t i = 0; i < headsPerWarp; i++) { + uint32_t const idxHead = wid + nbMathWarps * i; + float sink = + expf(attentionSinks[mha::min(idxHead, headGrpSize - 1) + idxHeadGrp * headGrpSize] - + states[i].max); + states[i].sum += sink; + } + } + __syncthreads(); + uint32_t const outOffset = + headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp); + auto const dst = &output[outOffset]; + for (uint32_t i = 0; i < headsPerWarp; i++) { + uint32_t const idxHead = wid + nbMathWarps * i; + if ((ctaNbValidQHeads % nbMathWarps != 0) && (idxHead >= ctaNbValidQHeads)) { + break; + } +#if SPEC_DEC + uint32_t const idxToken = idxHead / headGrpSize; + if (idxToken >= ctaNbValidTokens) { + break; + } + uint32_t const tokenPad = headGrpSize * (nbKHeads - 1); + uint32_t const idxDstHead = idxHead + idxToken * tokenPad; +#else + uint32_t const idxDstHead = idxHead; +#endif + auto const& s = states[i]; + auto const outData = convert(s.acc * (1.f / s.sum)); + if (Acc::size * lane < validElemsPerHead) { + reinterpret_cast&>(dst[idxDstHead][Acc::size * lane]) = + outData; + } + } + } else if (wid < nbMathWarps + MultiBlockSMem::nbIOWarps) { + static_assert(MultiBlockSMem::nbIOWarps <= MultiBlockSMem::nbBuf); + ScratchMem const scratchMem{scratch, maxNbSubSeq * nbKHeads * batchSize, nbInputSeqSplit}; + uint32_t const idxSeq = nbKHeads * idxReq + idxHeadGrp; + uint32_t const initIdxBlock = wid - nbMathWarps; + // each warp loads data for a block + for (uint32_t idxBlock = initIdxBlock; idxBlock < nbSubSeq; + idxBlock += MultiBlockSMem::nbIOWarps) { + uint32_t const idxAllSubSeq = maxNbSubSeq * idxSeq + idxBlock; + uint32_t const idxChunk = idxAllSubSeq * nbInputSeqSplit + idxInputSubSeq; + uint32_t const idxBuf = idxBlock % MultiBlockSMem::nbBuf; + auto& bar = smem.barriers[idxBuf]; + bar.consumed.wait_parity(idxBlock / MultiBlockSMem::nbBuf % 2 != 0); + auto const lane = laneId(); +#pragma unroll + for (uint32_t iter = 0; iter < divUp(ctaNbValidQHeads, 
warp_size); iter++) { + uint32_t const i = iter * warp_size + lane; + if (ctaNbValidQHeads % warp_size != 0 && i >= ctaNbValidQHeads) { + break; + } + ldgsts::copyAsync( + &smem.rowSumMax[idxBuf][i], &scratchMem.rowSumMax()[idxChunk][i]); + } + ldgsts::barArrive(bar.produced, false); + if constexpr (isHeadPadded) { + static_assert(grainsPerPaddedInputHead <= warp_size); + constexpr uint32_t headsPerIter = exactDiv(warp_size, grainsPerPaddedInputHead); + constexpr uint32_t nbIters = divUp(ctaNbValidQHeads, headsPerIter); + constexpr uint32_t nbWholeIters = ctaNbValidQHeads / headsPerIter; +#pragma unroll + for (uint32_t i = 0; i < nbIters; i++) { + uint32_t const idxHead = + headsPerIter * i + + BoundedVal{lane}.template divBy().get(); + uint32_t const idxGrain = + BoundedVal{lane}.template mod().get(); + if (i < nbWholeIters || idxHead < ctaNbValidQHeads) { + constexpr uint32_t nbElemsPerGrain = + exactDiv(grainBytes, sizeof(MultiBlockSMem::Elem)); + auto const dst = &smem.tokens[idxBuf][idxHead][nbElemsPerGrain * idxGrain]; + auto const src = + idxGrain < grainsPerIOHead + ? &scratchMem.tokens()[idxChunk][idxHead][nbElemsPerGrain * idxGrain] + : nullptr; + ldgsts::copyAsync(dst, src, idxGrain < grainsPerIOHead ? grainBytes : 0U); + } + } + ldgsts::barArrive(bar.produced, true); + } else { + if (warpElectSync()) { + tma::loadLinearAsync(&smem.tokens[idxBuf], &scratchMem.tokens()[idxChunk], + sizeof(smem.tokens[idxBuf]), bar.produced); + arrive_tx(bar.produced, sizeof(smem.tokens[idxBuf]), 1); + } + } + } + __syncthreads(); + uint32_t const idxBar = tid - warp_size * nbMathWarps; + if (idxBar < MultiBlockSMem::nbBuf * 2) { + reinterpret_cast(&smem.barriers[0])[idxBar].~CtaBarrier(); + } + } + } +#else +#if GENERATE_CUBIN + static_assert("This kernel is for Hopper only"); +#else + asm volatile("trap;\n"); +#endif +#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && BEAM_WIDTH == 1 +} + +#if CACHE_ELEM_ENUM == 0 || CACHE_ELEM_ENUM == 2 +template +__device__ inline typename F16QToF8Converter::RegData +F16QToF8Converter::load(uint32_t tid, TinyPtr const& src, + uint32_t const nbKHeads /*for beam search only*/, + uint32_t nbTokens) { +#if !(SPEC_DEC) + assert(nbTokens == 1); + nbTokens = 1; +#endif + typename F16QToF8Converter::RegData dst; +#pragma unroll + for (uint32_t iter = 0; iter < nbIters; iter++) { + uint32_t const idxGrain = nbThrds * iter + tid; + if (idxGrain >= totalGrains) { + break; + } +#if SPEC_DEC + uint32_t const idxToken = idxGrain / grainsPerPaddedInputQHeadGrp; + uint32_t const tokenPad = grainsPerPaddedInputQHeadGrp * (nbKHeads - 1); + uint32_t offsetInGrains = idxGrain + tokenPad * idxToken; + static_assert(beamWidth == 1); +#else + uint32_t const idxBeam = beamWidth == 1 ? 0 : idxGrain / grainsPerPaddedInputQHeadGrp; + uint32_t const beamPad = grainsPerPaddedInputQHeadGrp * (nbKHeads - 1); + uint32_t offsetInGrains = idxGrain + beamPad * idxBeam; +#endif + bool isGrainInBound = true; + if constexpr (isHeadPadded) { + uint32_t const idxGrainInsideHead = offsetInGrains % grainsPerPaddedInputHead; + offsetInGrains = + offsetInGrains / grainsPerPaddedInputHead * grainsPerIOHead + idxGrainInsideHead; + isGrainInBound = (idxGrainInsideHead < grainsPerIOHead); + } +#if SPEC_DEC + isGrainInBound = isGrainInBound && (idxToken < nbTokens); +#endif + LdGrain const srcGrain = + isGrainInBound ? 
src.template cast()[offsetInGrains] : LdGrain{}; + static_assert(inputElemSize == 2); + auto const& fp16Data = + reinterpret_cast const&>(srcGrain); + dst[iter] = idxGrain % grainsPerPaddedInputHead < grainsPerIOHead + ? fp16Data + : mha::decay_t{}; + } + return dst; +} + +template +__device__ inline void F16QToF8Converter::store( + uint32_t tid, SharedMem::QBuffer& dst, + F16QToF8Converter::RegData const& data) { +#pragma unroll + for (uint32_t iter = 0; iter < nbIters; iter++) { + uint32_t const idxGrain = nbThrds * iter + tid; + if (idxGrain >= totalGrains) { + break; + } +#if CACHE_ELEM_ENUM == 0 + static_assert(inputElemSize == cacheElemSize); + ShmVec const& shmData = data[iter]; + uint32_t const r = idxGrain / grainsPerPaddedInputHead; + BoundedVal const c = {idxGrain % grainsPerPaddedInputHead}; + + dst[c.template divBy().get()].template at( + r, c.template mod().get()) = reinterpret_cast(shmData); +#else + auto const& fp16Data = data[iter]; + ShmVec shmData; +#pragma unroll + for (uint32_t i = 0; i < fp16Data.size; i++) { + shmData[i] = CacheElem{fp16Data[i]}; + } + uint32_t const dstIdxGrain = idxGrain / 2; + uint32_t const dstIdxHalfGrain = idxGrain % 2; + constexpr uint32_t grainsPerCacheHead = exactDiv(paddedCacheHeadBytes, grainBytes); + uint32_t const r = dstIdxGrain / grainsPerCacheHead; + BoundedVal const c = {dstIdxGrain % grainsPerCacheHead}; + reinterpret_cast&>( + dst[c.template divBy().get()].template at( + r, c.template mod().get()))[dstIdxHalfGrain] = shmData; +#endif + } +} +#endif + +__device__ inline KVTilePartLoader::KVTilePartLoader(bool isK, uint32_t nbKHeads, + KVCacheList const& cacheList, + uint32_t idxReq, uint32_t idxHeadGrp, + CUtensorMap const& tensorMap +#if USE_PAGED_KV_CACHE + , + uint32_t nbPages, + Vec& pageBuf +#endif + ) + : nbKHeads{nbKHeads}, + cacheList{cacheList}, + idxReq{idxReq}, + idxHeadGrp{idxHeadGrp}, + tensorMap{tensorMap} +#if USE_PAGED_KV_CACHE + , + nbPages{nbPages}, + pages{pageBuf} +#if PAGED_KV_CACHE_LAYOUT == 1 + , + baseOffset{idxReq * cacheList.maxNbPagesPerSeq} +#else + , + baseOffset{((idxReq * beamWidth) * 2 + (isK ? 0 : 1)) * cacheList.maxNbPagesPerSeq} +#endif +#else + , + baseOffset{(idxReq * beamWidth) * 2 + (isK ? 
0 : 1)} +#endif +{ +} + +// tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache +template +__device__ inline void KVTilePartLoader::loadData( + Array2D& dst, + uint32_t idxTile, uint32_t idxPart, CtaBarrier& bar) { + static_assert(nbTokens == gemm0CtaTileNbTokens); +#if USE_PAGED_KV_CACHE + assert(idxTile == idxTileRef); + if constexpr (nbTokens < tokensPerPage) { + assert(nbPagesPerTile == 1); + uint32_t const offset = nbTokens * (idxTile % exactDiv(tokensPerPage, nbTokens)); +#if PAGED_KV_CACHE_LAYOUT == 1 + tma::loadAsync(&dst, tensorMap, + DimsLE<4>{partElems * idxPart, idxHeadGrp, offset, (uint32_t)pages[0]}, bar); +#else + tma::loadAsync(&dst, tensorMap, + DimsLE<4>{partElems * idxPart, offset, idxHeadGrp, (uint32_t)pages[0]}, bar); +#endif + } else { +#pragma unroll + for (uint32_t i = 0; i < nbPagesPerTile; i++) { +#if PAGED_KV_CACHE_LAYOUT == 1 + tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap, + DimsLE<4>{partElems * idxPart, idxHeadGrp, 0, (uint32_t)pages[i]}, bar); +#else + tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap, + DimsLE<4>{partElems * idxPart, 0, idxHeadGrp, (uint32_t)pages[i]}, bar); +#endif + } + } +#else + tma::loadAsync(&dst, tensorMap, + DimsLE<4>{partElems * idxPart, nbTokens * idxTile, idxHeadGrp, baseOffset}, bar); +#endif +} + +__device__ inline void KVTilePartLoader::loadPages(uint32_t idxTile) { +#if USE_PAGED_KV_CACHE + uint32_t const idxPageBeg = gemm0CtaTileNbTokens >= tokensPerPage + ? nbPagesPerTile * idxTile + : idxTile / exactDiv(tokensPerPage, gemm0CtaTileNbTokens); +#pragma unroll + for (uint32_t i = 0; i < nbPagesPerTile; i++) { + uint32_t const idxPage = idxPageBeg + i; + auto const page = + idxPage < nbPages ? cacheList.kvCachePageList[baseOffset + idxPage] : kBAD_PAGE_INDEX; + if (warpElectSync()) { + pages[i] = page; + } + } + idxTileRef = idxTile; + __syncwarp(); +#endif +} + +__device__ inline GMemKVCacheHead& KVTilePartLoader::getHead(uint32_t pos) { + constexpr uint32_t nbTokens = gemm0CtaTileNbTokens; +#if USE_PAGED_KV_CACHE +#if PAGED_KV_CACHE_LAYOUT == 1 + // Raise a runtime error indicating not implemented + assert(false && "KVTilePartLoader::getHead is not implemented for PAGED_KV_CACHE_LAYOUT == 1"); + __trap(); +#else + uint32_t const idxTile = pos / nbTokens; + assert(idxTile == idxTileRef); + uint32_t const offset = pos % tokensPerPage; + return cacheList + .pool[tokensPerPage * (nbKHeads * pages[pos % nbTokens / tokensPerPage] + idxHeadGrp) + + offset]; +#endif +#else + // shape: KVCacheHead[batchSize][beamWidth][2][nbKHeads][capacity] + return cacheList.data[cacheList.capacity * (baseOffset * nbKHeads + idxHeadGrp) + pos]; +#endif +} + +#if SWAP_AB +#if SPEC_DEC +__device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec, +#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE + int32_t tok0WinBeg, +#endif + uint32_t cacheSeqLen, uint32_t idxTile, uint32_t warpRank) { + constexpr uint32_t tileSize = gemm0CtaTileNbTokens; + static_assert(SPEC_Q_SEQ_LEN <= sizeof(MaskType) * 8, "not implemented"); + + assert(cacheSeqLen >= SPEC_Q_SEQ_LEN); + uint32_t const maskStartRow = cacheSeqLen - SPEC_Q_SEQ_LEN; + uint32_t const tileStartRow = tileSize * idxTile; + if (tileStartRow + tileSize < maskStartRow) { + return; + } + + uint32_t const idxInQuad = laneId() % 4; + uint32_t const idxQuad = laneId() / 4; + +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + uint32_t const col = GmmaAccCoreMat::cols * (4 * n + 
idxInQuad) + j; + uint32_t const maskCol = col / headGrpSize; + MaskType const bit_mask = (1ULL << (maskCol + 1)) - 1; + +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + uint32_t const row = gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad; + uint32_t const globalRow = tileStartRow + row; + if (globalRow >= cacheSeqLen) { + acc(m, n)(i, j) = safeInitRowMax; + continue; + } + if (globalRow >= maskStartRow) { + uint32_t const maskRow = globalRow - maskStartRow; + if ((bit_mask >> maskRow) == 0) { + acc(m, n)(i, j) = safeInitRowMax; + } + } + } + } + } + } +} +#endif // SPEC_DEC + +// smemColMax is persistent across multiple iterations +__device__ inline RegColWiseVec computeWarpGrpColMax_sync(CtaBarrier& warpGrpBar, + ShmQWiseVec& smemColMax, + Gemm0Acc const& src) { + auto colMax = RegColWiseVec::filled(Vec::filled(safeInitRowMax)); +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) { + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { +#pragma unroll + for (uint32_t m = 0; m < src.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + colMax[n][j] = (m == 0 && i == 0) ? src(m, n)(i, j) : fmax(colMax[n][j], src(m, n)(i, j)); + } + } + } + } + +#pragma unroll + for (uint32_t xorMask = 16; xorMask > 2; xorMask /= 2) { +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < 2; j++) { + auto& x = colMax[n][j]; + x = fmax(x, __shfl_xor_sync(~0U, x, xorMask)); + } + } + } + + uint32_t const lane = laneId(); + if (lane < 4) { +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < 2; j++) { + atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]); + } + } + } + warpGrpBar.arrive_and_wait(); + uint32_t const idxInQuad = lane % 4; + +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + assert(colMax[n][j] <= smemColMax[8 * n + 2 * idxInQuad + j]); + colMax[n][j] = smemColMax[8 * n + 2 * idxInQuad + j]; + } + } + warpGrpBar.arrive_and_wait(); + return colMax; +} + +__device__ inline RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smemVec) { + RegColWiseVec ret; + constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols); + auto const idx = laneId() % nbThrdsPerInstNBase; +#pragma unroll + for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++) { + static_assert(nbThrdsPerInstNBase * RegColWiseVec::size == + exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)); + ret[i] = reinterpret_cast, + exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>( + smemVec)[i * nbThrdsPerInstNBase + idx]; + } + return ret; +} + +__device__ inline RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec, + uint32_t bound) { + RegColWiseVec ret; + constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols); + auto const idx = laneId() % nbThrdsPerInstNBase; +#pragma unroll + for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++) { + static_assert(nbThrdsPerInstNBase * RegColWiseVec::size == + exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)); + ret[i] = reinterpret_cast, + exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>( + gmemVec)[mha::min(i * nbThrdsPerInstNBase + idx, bound)]; + } + return ret; +} + +__device__ inline void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& 
acc, uint32_t validRowBeg, + uint32_t validRowEnd) { + uint32_t const idxInQuad = laneId() % 4; + uint32_t const idxQuad = laneId() / 4; +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + uint32_t const row = 64 * m + 16 * warpRank + 8 * i + idxQuad; + if (row >= validRowBeg && row < validRowEnd) { + continue; + } +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + acc(m, n)(i, j) = safeInitRowMax; + } + } + } + } +} + +__device__ inline void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegColWiseVec const& colMax) { +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + float const maxVal = colMax[n][j]; + float const bias = maxVal * log2e; +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + float& elem = acc(m, n)(i, j); + assert(maxVal >= elem); + elem = exp2f(elem * log2e - bias); + } + } + } + } +} + +__device__ inline RegColWiseVec computeWarpColSum(Gemm0Acc& src) { + auto colSum = RegColWiseVec::filled(Vec::filled(0)); +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { +#pragma unroll + for (uint32_t m = 0; m < src.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + colSum[n][j] = (m == 0 && i == 0) ? src(m, n)(i, j) : colSum[n][j] + src(m, n)(i, j); + } + } + } + } + +#pragma unroll + for (uint32_t xorMask = 16; xorMask > 2; xorMask /= 2) { +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + auto& x = colSum[n][j]; + x += __shfl_xor_sync(~0U, x, xorMask); + } + } + } + return colSum; +} + +__device__ inline void storeGemm0AccToShm(uint32_t warpRank, uint32_t lane, + SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, + Gemm0Acc const& acc) { +#if CACHE_ELEM_ENUM == 0 + using F16Acc = Array2D, Gemm0Acc::rows, Gemm0Acc::cols>; + F16Acc f16Acc; + reinterpret_cast&>(f16Acc) = + convert(reinterpret_cast const&>(acc)); + static_assert(Gemm0Acc::size == 1 || Gemm0Acc::size % 2 == 0); + uint32_t const idxHalf = lane / 16; + uint32_t const idxInHalf = lane % 16; + uint32_t const idxOctInsideHalf = idxInHalf / 8; + uint32_t const idxRowInsideOct = lane % 8; + uint32_t const warpBaseC = 16 * warpRank; + auto const toAccCoords = [](uint32_t const idxAccCoreMat) -> std::pair { + uint32_t const accR = idxAccCoreMat / Gemm0Acc::cols; + uint32_t const accC = idxAccCoreMat % Gemm0Acc::cols; + return {accR, accC}; + }; + auto const getDstAddr = [&](uint32_t idxAccCoreMat) -> LdGrain* { + auto const [accR, accC] = toAccCoords(idxAccCoreMat); + static_assert(sizeof(MathElem) * gemm0CtaTileNbTokens == xPartBytes); + uint32_t const idxPart = 0; + uint32_t const dstR = accC * 8 + idxRowInsideOct; + uint32_t const dstC = + exactDiv(gmma::instM * accR + warpBaseC + 8 * idxOctInsideHalf, cacheElemsPerGrain); + assert(dstC / exactDiv(xPartBytes, grainBytes) == idxPart); + return &smemX[idxPart].template at(dstR, dstC); + }; + auto const getAccData = [&](uint32_t idxAccCoreMat) { + auto const [accR, accC] = toAccCoords(idxAccCoreMat); + return f16Acc(accR, accC); + }; + + barConsumed.arrive_and_wait(); +#pragma unroll + for (uint32_t iter = 0; iter < Gemm0Acc::size / 2; iter++) { + auto const dstAddr = 
getDstAddr(iter * 2 + idxHalf); + Vec const data[2] = {getAccData(iter * 2), getAccData(iter * 2 + 1)}; + stmatrix(dstAddr, reinterpret_cast(data)); + } + if constexpr (Gemm0Acc::size % 2 != 0) { + auto const dstAddr = lane < 16 ? getDstAddr(Gemm0Acc::size - 1) : nullptr; + stmatrix(dstAddr, getAccData(Gemm0Acc::size - 1)); + } +#elif CACHE_ELEM_ENUM == 2 + using F8Acc = Array2D; + F8Acc f8Acc; +#pragma unroll + for (uint32_t i = 0; i < acc.rows; i++) { +#pragma unroll + for (uint32_t j = 0; j < acc.cols; j++) { + auto const& core = acc(i, j); + static_assert(mha::is_same_v); + Vec const f8Data = { + __nv_cvt_float2_to_fp8x2(float2{core(0, 0), core(1, 0)}, __NV_SATFINITE, __NV_E4M3), + __nv_cvt_float2_to_fp8x2(float2{core(0, 1), core(1, 1)}, __NV_SATFINITE, __NV_E4M3)}; + f8Acc(i, j) = reinterpret_cast(f8Data); + } + } + + if constexpr (F8Acc::size == 4 || F8Acc::size == 2 || F8Acc::size == 1) { + LdGrain* dst = nullptr; + if (F8Acc::size == 4 || lane < 8 * F8Acc::size) { + uint32_t const idxCore = lane / 8; + uint32_t const srcRow = idxCore / F8Acc::cols; + uint32_t const srcCol = idxCore % F8Acc::cols; + uint32_t const dstCoreRow = lane % 8; + uint32_t const dstRow = srcCol * 8 + dstCoreRow; + BoundedVal const dstCol{ + srcRow * 4 + warpRank}; + dst = &smemX[dstCol.template divBy().get()].template at( + dstRow, dstCol.template mod().get()); + } + barConsumed.arrive_and_wait(); + stmatrix(dst, reinterpret_cast const&>(f8Acc)); + } else { + // we need to use loops + assert(false); + trap(); + } +#endif +} + +#else + +__device__ inline RegRowWiseVec warpRowWiseReduce(RegRowWiseVec const& init, Gemm0Acc const& src, + float (*op)(float, float)) { + RegRowWiseVec vec = init; +#pragma unroll + for (uint32_t m = 0; m < src.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + // @fixme: check if compiler is reordering these op to hide latency. 
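+        // (illustrative) This nest only folds each thread's own registers; the
+        // xorMask loop below finishes the reduction across the four lanes of the
+        // quad sharing a row, with two butterfly steps (xorMask = 2, then 1):
+        //   x = op(x, __shfl_xor_sync(~0U, x, 2));  // lane 0 combines with lane 2
+        //   x = op(x, __shfl_xor_sync(~0U, x, 1));  // then with lane 1: all four lanes folded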
+ vec[m][i] = op(vec[m][i], src(m, n)(i, j)); + } + } + } + } + +#pragma unroll + for (uint32_t xorMask = 2; xorMask != 0; xorMask /= 2) { +#pragma unroll + for (uint32_t m = 0; m < src.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + auto& x = vec[m][i]; + x = op(x, __shfl_xor_sync(~0U, x, xorMask)); + } + } + } + return vec; +} + +__device__ inline RegRowWiseVec computeWarpGrpRowMax_sync(uint32_t warpRank, + ShmQWiseVec& smemRowMax, + Gemm0Acc const& src) { + assert(warpRank < 4); + RegRowWiseVec const init = loadShmRowWiseVecWithDup(warpRank, smemRowMax); + RegRowWiseVec rowMax = warpRowWiseReduce(init, src, fmax); + + storeShmRowWiseVec(warpRank, smemRowMax, rowMax); + __syncwarp(); + return rowMax; +} + +#if SPEC_DEC +__device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec, +#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE + int32_t tok0WinBeg, +#endif + uint32_t cacheSeqLen, uint32_t idxTile, uint32_t warpRank) { + constexpr uint32_t tileSize = gemm0CtaTileNbTokens; + auto const inputSeqLen = specDec.inputSeqLen; + auto const idxInputSubSeq = specDec.idxInputSubSeq; + constexpr uint64_t fullMask = ~uint64_t{0}; + static_assert(tileSize == sizeof(fullMask) * 8); +#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE + uint32_t const ctaTokOffset = inputTokensPerCta * idxInputSubSeq; + Range const tileRange = {tileSize * idxTile, tileSize * idxTile + tileSize}; + Range const maxMaskOutRange = {0, mha::max(0, tok0WinBeg) + (inputTokensPerCta - 1)}; + bool const ctaNeedBegMask = tileRange.beg < maxMaskOutRange.end; + assert(ctaNeedBegMask == overlap(tileRange, maxMaskOutRange)); + int32_t const tok0NbMaskOut = int32_t(tok0WinBeg) - int32_t(tileSize * idxTile); +#else + constexpr bool ctaNeedBegMask = false; + uint64_t const begMask = fullMask; + int32_t const tok0NbMaskOut = -2147483648; +#endif + uint32_t const offset = tileSize * idxTile; + uint32_t const nbValidCols = mha::min(offset < cacheSeqLen ? cacheSeqLen - offset : 0U, tileSize); + bool const ctaNeedEndMask = (nbValidCols < tileSize); + bool const ctaNeedSpecDecMask = specDec.needMask(idxTile, 0); + bool const needMask = ctaNeedBegMask || ctaNeedEndMask || ctaNeedSpecDecMask; + if (!needMask) { + return; + } + static_assert(tileSize == 64, "not implemented"); + auto const endMask = fullMask >> (tileSize - nbValidCols); + + uint32_t const idxInQuad = laneId() % 4; + uint32_t const idxQuad = laneId() / 4; +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + uint32_t const row = gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad; + uint32_t const idxQTokInCta = row / headGrpSize; + bool const isQTokValid = + (headGrpSize * inputTokensPerCta == ctaNbQHeads) || (idxQTokInCta < inputTokensPerCta); + auto const specDecMask = (isQTokValid && specDec.needMask(idxTile, idxQTokInCta)) + ? specDec.loadTileMaskRow(idxTile, idxQTokInCta) + : SpecDec::TileMaskRow{~0U, ~0U}; +#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE + int32_t const begNbMaskOut = tok0NbMaskOut + int32_t(idxQTokInCta); + uint64_t const begMask = (begNbMaskOut > 0 ? 
fullMask << begNbMaskOut : fullMask); +#else + uint64_t const begMask = fullMask; +#endif + auto const mask = begMask & endMask & reinterpret_cast(specDecMask); + if (mask == ~uint64_t{0}) { + continue; + } +#if DBG_PRINT + if (idxInQuad == 0) { + printf("mask at row %d: %lx\n", row, mask); + } +#endif +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + uint32_t const col = GmmaAccCoreMat::cols * (4 * n + idxInQuad) + j; + assert((col < nbValidCols) == bool(endMask & (1ULL << col))); + if ((mask & (1ULL << col)) == 0) { + acc(m, n)(i, j) = safeInitRowMax; + } + } + } + } + } +} +#else +__device__ inline void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uint32_t validColEnd) { + uint32_t const idxInQuad = laneId() % 4; +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + uint32_t const col = GmmaAccCoreMat::cols * (4 * n + idxInQuad) + j; + if (col >= validColBeg && col < validColEnd) { + continue; + } +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + acc(m, n)(i, j) = safeInitRowMax; + } + } + } + } +} +#endif + +__device__ inline void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegRowWiseVec const& rowMax) { +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + float const maxVal = rowMax[m][i]; + float const bias = maxVal * log2e; +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + float& elem = acc(m, n)(i, j); + assert(maxVal >= elem); + elem = exp2f(elem * log2e - bias); + } + } + } + } +} + +__device__ inline RegRowWiseVec computeWarpRowSum(Gemm0Acc& src) { + return warpRowWiseReduce(RegRowWiseVec{}, src, [](float a, float b) { return a + b; }); +} + +__device__ inline RegRowWiseVec loadShmRowWiseVecWithDup(uint32_t warpRank, + ShmQWiseVec const& smemVec) { + RegRowWiseVec vec; + uint32_t const idxQuad = laneId() / 4; +#pragma unroll + for (uint32_t m = 0; m < RegRowWiseVec::size; m++) { +#pragma unroll + for (uint32_t i = 0; i < RegRowWiseVec::Elem::size; i++) { + vec[m][i] = smemVec[gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad]; + } + } + return vec; +} + +__device__ void storeShmRowWiseVec(uint32_t warpRank, ShmQWiseVec& smemVec, + RegRowWiseVec const& regVec) { + uint32_t const lane = laneId(); + uint32_t const idxQuad = lane / 4; + uint32_t const idxInQuad = lane % 4; + bool const enable = (idxInQuad == 0); +#pragma unroll + for (uint32_t m = 0; m < RegRowWiseVec::size; m++) { +#pragma unroll + for (uint32_t i = 0; i < RegRowWiseVec::Elem::size; i++) { + assert(__shfl_sync(~0U, regVec[m][i], idxQuad * 4) == regVec[m][i]); + if (enable) { + smemVec[gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad] = regVec[m][i]; + } + } + } +} + +// for X +// order: 0,8,1,9, 2,10,3,11, 4,12,5,13, 6,14,7,15, ... 
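+// (illustrative derivation) The interleaving above falls out of the fp8 pack in the
+// function below: each 32-bit word packs {acc(m, 2n)(i, 0), acc(m, 2n+1)(i, 0),
+// acc(m, 2n)(i, 1), acc(m, 2n+1)(i, 1)}, and acc core-matrix column n in lane q of a
+// quad covers X columns 8*n + 2*q + j. Lane q = 0 with n = 0 therefore emits columns
+// {0, 8, 1, 9}; q = 1 emits {2, 10, 3, 11}, and so on across the quad.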
+__device__ inline void storeGemm0AccToShm(uint32_t warpRank, uint32_t lane, + SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, + Gemm0Acc const& acc) { + uint32_t const idxMat = lane / 8; + uint32_t const idxRow = lane % 8; + barConsumed.arrive_and_wait(); +#pragma unroll + for (uint32_t m = 0; m < Gemm0Acc::rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + Vec fp8Data; +#pragma unroll + for (uint32_t n = 0; n < exactDiv(Gemm0Acc::cols, 2); n++) { + reinterpret_cast&>(fp8Data[n]) = { + __nv_fp8x2_e4m3(float2{acc(m, n * 2)(i, 0), acc(m, n * 2 + 1)(i, 0)}), + __nv_fp8x2_e4m3(float2{acc(m, n * 2)(i, 1), acc(m, n * 2 + 1)(i, 1)})}; + } + static_assert(decltype(fp8Data)::size == 4); + stmatrix_4x(this_warp(), + &smemX[m].template at(16 * warpRank + 8 * i + idxRow, idxMat), + fp8Data); + } + } +} +#endif + +#if SWAP_AB +__device__ inline Vec loadVTileTransposed( + uint32_t warpRank, uint32_t lane, SharedMem::VBuffer const& smemV, uint32_t idxGmmaInstK) { + Vec fragA; + constexpr uint32_t instK = gmma::instK; +#pragma unroll + for (uint32_t i = 0; i < gemm1NbGmmaInstM; i++) { + static_assert(exactDiv(gmma::instM, gmmaWarpsPerGrp) == grainBytes); + constexpr uint32_t grainsPerPart = exactDiv(cacheHeadPartBytes, grainBytes); +#if CACHE_ELEM_ENUM == 0 + uint32_t idxRow = lane % 8; + uint32_t idxMat = lane / 8; + uint32_t c = idxMat % 2; + uint32_t r = idxMat / 2; + auto const col = BoundedVal<2 * gmmaWarpsPerGrp * gemm1NbGmmaInstM>{ + 2 * (gmmaWarpsPerGrp * i + warpRank) + c}; + auto const src = &smemV[col.template divBy().get()].template at( + instK * idxGmmaInstK + 8 * r + idxRow, col.template mod().get()); + auto const data = ldmatrix(src); + fragA[i] = reinterpret_cast(data); +#elif CACHE_ELEM_ENUM == 2 + auto const col = BoundedVal{gmmaWarpsPerGrp * i + warpRank}; + LdGrain const* src = &smemV[col.template divBy().get()].template at( + instK * idxGmmaInstK + lane, col.template mod().get()); + auto const data = ldmatrix(src); + fragA[i](0, 0)(0, 0) = prmt(data[0], data[1], {0, 4, 2, 6}); + fragA[i](0, 0)(1, 0) = prmt(data[0], data[1], {1, 5, 3, 7}); + fragA[i](0, 1)(0, 0) = prmt(data[2], data[3], {0, 4, 2, 6}); + fragA[i](0, 1)(1, 0) = prmt(data[2], data[3], {1, 5, 3, 7}); +#endif + } + return fragA; +} +#else +__device__ inline void transposeVTile(uint32_t warpRank, uint32_t lane, SharedMem::VTBuffer& dst, + SharedMem::VBuffer const& src) { + uint32_t const idxMat = lane / 8; + uint32_t const idxRow = lane % 8; +#pragma unroll + for (uint32_t m = 0; m < exactDiv(SharedMem::VTBuffer::rows, gmma::instM); m++) { + static_assert(cacheHeadPartElems >= gmma::instM); + uint32_t const idxPart = gmma::instM * m / cacheHeadPartElems; + constexpr uint32_t grainsPerCacheHeadPart = exactDiv(cacheHeadPartElems, cacheElemsPerGrain); +#pragma unroll + for (uint32_t n = 0; n < exactDiv(SharedMem::VTBuffer::cols, 2); n++) { + LdGrain const a = ldmatrix_4x( + this_warp(), &src[idxPart].template at( + 32 * n + lane, exactDiv(gmma::instM, cacheElemsPerGrain) * m - + grainsPerCacheHeadPart * idxPart + warpRank)); + LdGrain const b = {prmt(a[0], a[1], {0, 4, 2, 6}), prmt(a[0], a[1], {1, 5, 3, 7}), + prmt(a[2], a[3], {0, 4, 2, 6}), prmt(a[2], a[3], {1, 5, 3, 7})}; + uint32_t const i = idxMat % 2; + uint32_t const j = idxMat / 2; + stmatrix_4x( + this_warp(), + &dst.template at(gmma::instM * m + 16 * warpRank + 8 * i + idxRow, 2 * n + j), b); + } + } +} +#endif + +#if SWAP_AB +__device__ inline Vec loadShmColWiseVecNoDup( + ShmQWiseVec const& shmVec) { + Vec ret; +#pragma unroll 
+  for (uint32_t i = 0; i < divUp(ShmQWiseVec::size, warp_size); i++) {
+    uint32_t const idx = i * warp_size + laneId();
+    bool const inBound = ((ShmQWiseVec::size % warp_size == 0) || (idx < ShmQWiseVec::size));
+    ret[i] = (inBound ? shmVec[idx] : 0);
+  }
+  return ret;
+}
+
+__device__ inline void storeShmColWiseVecNoDup(
+    ShmQWiseVec& shmVec, Vec const& src) {
+#pragma unroll
+  for (uint32_t i = 0; i < divUp(ShmQWiseVec::size, warp_size); i++) {
+    uint32_t const idx = i * warp_size + laneId();
+    bool const inBound = ((ShmQWiseVec::size % warp_size == 0) || (idx < ShmQWiseVec::size));
+    if (inBound) {
+      shmVec[idx] = src[i];
+    }
+  }
+}
+#else
+__device__ inline Vec
+loadShmRowWiseVecNoDup(uint32_t warpRank, ShmQWiseVec const& shmVec) {
+  constexpr uint32_t const nbElems = exactDiv(ShmQWiseVec::size, gmma::instM) * (gmma::instM / 4);
+  Vec ret;
+  uint32_t const lane = laneId();
+  uint32_t const idxHalf = lane / (gmma::instM / 4);
+  uint32_t const idxInHalf = lane % (gmma::instM / 4);
+#pragma unroll
+  for (uint32_t i = 0; i < divUp(nbElems, warp_size); i++) {
+    uint32_t const idx =
+        gmma::instM * 2 * i + gmma::instM * idxHalf + (gmma::instM / 4) * warpRank + idxInHalf;
+    bool const inBound = ((nbElems % warp_size == 0) || (i + 1 < divUp(nbElems, warp_size)) ||
+                          (idx < ShmQWiseVec::size));
+    ret[i] = (inBound ? shmVec[idx] : 0);
+  }
+  return ret;
+}
+
+__device__ inline void storeShmRowWiseVecNoDup(
+    uint32_t warpRank, ShmQWiseVec& shmVec,
+    Vec const& src) {
+  constexpr uint32_t const nbElems = exactDiv(ShmQWiseVec::size, gmma::instM) * (gmma::instM / 4);
+  uint32_t const lane = laneId();
+  uint32_t const idxHalf = lane / (gmma::instM / 4);
+  uint32_t const idxInHalf = lane % (gmma::instM / 4);
+#pragma unroll
+  for (uint32_t i = 0; i < divUp(nbElems, warp_size); i++) {
+    uint32_t const idx =
+        gmma::instM * 2 * i + gmma::instM * idxHalf + (gmma::instM / 4) * warpRank + idxInHalf;
+    bool const inBound = ((nbElems % warp_size == 0) || (i + 1 < divUp(nbElems, warp_size)) ||
+                          (idx < ShmQWiseVec::size));
+    if (inBound) {
+      shmVec[idx] = src[i];
+    }
+  }
+}
+#endif
+
+#if SWAP_AB
+__device__ inline void rescaleGemm1AccForNewColMax_sync(
+    uint32_t warpRank, ShmQWiseVec const& shmXColMax, ShmQWiseVec const (&shmXColSum)[gemm0NbWarps],
+    ShmQWiseVec& shmAccColMax, Gemm1Acc& acc, ShmQWiseVec& shmAccColSum,
+    CtaBarrier& gemm1WarpGrpBar) {
+  auto accColSum = loadShmColWiseVecNoDup(shmAccColSum);
+
+  auto const xColMax = loadShmColWiseVecNoDup(shmXColMax);
+  auto const accColMax = loadShmColWiseVecNoDup(shmAccColMax);
+  auto token = gemm1WarpGrpBar.arrive();
+  auto const needRescaleVec = (accColMax < xColMax);
+  UniformNeedRescaleMask rescaleMask;
+  bool anyNeedRescale = false;
+#pragma unroll
+  for (uint32_t i = 0; i < rescaleMask.size; i++) {
+    assert(accColMax[i] <= xColMax[i]);
+    rescaleMask[i] = __ballot_sync(~0U, needRescaleVec[i]);
+    anyNeedRescale = anyNeedRescale || (rescaleMask[i] != 0);
+  }
+  if (anyNeedRescale) {
+    auto const scaleVec = expf(accColMax - xColMax);
+    auto const lane = laneId();
+#pragma unroll
+    for (uint32_t n = 0; n < Gemm1Acc::cols; n++) {
+      uint32_t const vecIdx = gmma::instNBase * n / warp_size;
+      uint32_t const offset = gmma::instNBase * n % warp_size;
+      constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols);
+#pragma unroll
+      for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) {
+        auto const mask = ((rescaleMask[vecIdx] >> (offset + j)) & 0b01010101U);
+        auto getScale = [&] {
+          return __shfl_sync(~0U, scaleVec[vecIdx],
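+        /* Illustrative aside (not part of this patch): scaleVec above is the standard
+           streaming-softmax correction factor. A minimal scalar sketch of the update rule,
+           with illustrative names:
+             float const newMax = fmaxf(oldMax, tileMax);
+             float const scale = expf(oldMax - newMax);  // <= 1
+             acc *= scale;                               // rescale the previous accumulator
+             rowSum = rowSum * scale + tileSum;          // and the running denominator
+        */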
+ offset + lane % nbThrdsPerInstNBase * GmmaAccCoreMat::cols + j); + }; + assert((getScale() != 1) == + ((mask >> lane % nbThrdsPerInstNBase * GmmaAccCoreMat::cols) & 0x1U)); + bool const needRescale = (mask != 0); + if (!needRescale) { // this branch is warp-uniform + continue; + } + float const scale = getScale(); +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + acc(m, n)(i, j) *= scale; + } + } + } + } + accColSum = accColSum * scaleVec; + } + gemm1WarpGrpBar.wait(mha::move(token)); + + // @fixme: with atomic, we can let the first warp reaching here to do the update, instead of + // always warp 3. + uint32_t const warpRankForUpdate = gmmaWarpsPerGrp - 1; + if (warpRank == warpRankForUpdate) { + if (anyNeedRescale) { + storeShmColWiseVecNoDup(shmAccColMax, xColMax); + } +#pragma unroll + for (uint32_t i = 0; i < gemm0NbWarps; i++) { + accColSum = accColSum + loadShmColWiseVecNoDup(shmXColSum[i]); + } + storeShmColWiseVecNoDup(shmAccColSum, accColSum); + } + gemm1WarpGrpBar.arrive_and_wait(); +} +#else +__device__ inline void rescaleGemm1AccForNewRowMax_sync(uint32_t warpRank, + ShmQWiseVec const& shmXRowMax, + ShmQWiseVec const& shmXRowSum, + ShmQWiseVec& shmAccRowMax, Gemm1Acc& acc, + ShmQWiseVec& shmAccRowSum) { + auto accRowSum = loadShmRowWiseVecNoDup(warpRank, shmAccRowSum); + auto const xRowMax = loadShmRowWiseVecNoDup(warpRank, shmXRowMax); + auto const accRowMax = loadShmRowWiseVecNoDup(warpRank, shmAccRowMax); + assert(all(xRowMax >= accRowMax)); + auto const needRescaleVec = (accRowMax < xRowMax); + UniformNeedRescaleMask rescaleMask; + bool anyNeedRescale = false; +#pragma unroll + for (uint32_t i = 0; i < rescaleMask.size; i++) { + assert(accRowMax[i] <= xRowMax[i]); + rescaleMask[i] = __ballot_sync(~0U, needRescaleVec[i]); + anyNeedRescale = anyNeedRescale || (rescaleMask[i] != 0); + } + + if (anyNeedRescale) { + auto const scaleVec = expf(accRowMax - xRowMax); + auto const lane = laneId(); +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + uint8_t const mask = reinterpret_cast(rescaleMask[m / 2])[m % 2][i]; + bool const needRescale = (mask != 0); + if (needRescale) { // this branch is warp-uniform + float const scale = __shfl_sync(~0U, scaleVec[m / 2], 16 * (m % 2) + 8 * i + lane / 4); +#pragma unroll + for (uint32_t n = 0; n < Gemm1Acc::cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + acc(m, n)(i, j) *= scale; + } + } + } + } + } + accRowSum = accRowSum * scaleVec; + } + __syncwarp(); + auto const xRowSum = loadShmRowWiseVecNoDup(warpRank, shmXRowSum); + storeShmRowWiseVecNoDup(warpRank, shmAccRowSum, accRowSum + xRowSum); + storeShmRowWiseVecNoDup(warpRank, shmAccRowMax, xRowMax); + __syncwarp(); +} +#endif + +#if SWAP_AB +__device__ inline void rescaleAcc(Gemm1Acc& acc, RegColWiseVec const& scale) { +#pragma unroll + for (uint32_t n = 0; n < Gemm1Acc::cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + acc(m, n)(i, j) *= scale[n][j]; + } + } + } + } +} +#else +__device__ inline void rescaleAcc(Gemm1Acc& acc, RegRowWiseVec const& scale) { +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { +#pragma 
unroll + for (uint32_t n = 0; n < Gemm1Acc::cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + acc(m, n)(i, j) *= scale[m][i]; + } + } + } + } +} +#endif + +#if SWAP_AB +// @fixme: consider make this noinline +template +__device__ inline void saveTransposedOutput(uint32_t threadRank, uint32_t warpRank, DstHead* dst, + SharedMem::OutSwizzleBuf& swizzleBuf, + Gemm1Acc const& acc, CtaBarrier& warpGrpBar, + uint32_t nbKHeads) { + uint32_t const lane = laneId(); +#if CACHE_ELEM_ENUM == 0 + uint32_t const idxMat = lane / 8; + uint32_t const idxRow = lane % 8; +#elif CACHE_ELEM_ENUM == 2 + uint32_t const idxQuad = lane / 4; + uint32_t const idxInQuad = lane % 4; +#endif +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { +#pragma unroll + for (uint32_t n = 0; n < Gemm1Acc::cols; n++) { + auto const& core = acc(m, n); +#if CACHE_ELEM_ENUM == 0 + Vec f16Core; + reinterpret_cast&>(f16Core) = + convert(reinterpret_cast const&>(core)); + auto const dst = idxMat < 2 + ? &swizzleBuf.template at( + 8 * n + idxRow, 2 * (gmmaWarpsPerGrp * m + warpRank) + idxMat) + : nullptr; + stmatrix(dst, f16Core); +#elif CACHE_ELEM_ENUM == 2 + // each row is part of a b16 8x8 matrix and is transposed + Array2D coreTrans; + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + static_assert(GmmaAccCoreMat::cols == 2 && sizeof(InputElem) == 2); + InputElem2 const coreRow = float2ToInputElem2({core(i, 0), core(i, 1)}); + auto const coreRowTrans = movmatrix(reinterpret_cast(coreRow)); + reinterpret_cast(coreTrans(i, 0)) = coreRowTrans; + } + // expect compiler to generate two PRMT instructions + Vec const data = {coreTrans(0, 0), coreTrans(1, 0), coreTrans(0, 1), + coreTrans(1, 1)}; + swizzleBuf.template at( + gmma::instNBase * n + idxQuad, + (gmma::instM * m + exactDiv(gmma::instM, gmmaWarpsPerGrp) * warpRank) / 16)[idxInQuad] = + data; +#endif + } + } + warpGrpBar.arrive_and_wait(); + + constexpr uint32_t headsPerIter = exactDiv(grainBytes * gemm1NbThrds, paddedInputHeadBytes); + constexpr uint32_t nbIters = divUp(ctaNbValidQHeads, headsPerIter); + constexpr uint32_t nbWholeIters = ctaNbValidQHeads / headsPerIter; + constexpr uint32_t nbGrainsPerHead = exactDiv(paddedInputHeadBytes, grainBytes); + uint32_t const idxHeadBase = threadRank / nbGrainsPerHead; + uint32_t const idxGrain = threadRank % nbGrainsPerHead; +#pragma unroll + for (uint32_t iter = 0; iter < nbIters; iter++) { + uint32_t const idxHead = idxHeadBase + iter * headsPerIter; + if ((iter < nbWholeIters || idxHead < ctaNbValidQHeads) && + (!isHeadPadded || idxGrain < grainsPerIOHead)) { +#if CACHE_ELEM_ENUM == 0 + auto const data = swizzleBuf.template at(idxHead, idxGrain); +#elif CACHE_ELEM_ENUM == 2 + auto const data = reinterpret_cast&>( + swizzleBuf.template at(idxHead, idxGrain / 2))[idxGrain % 2]; +#endif + constexpr uint32_t inputElemsPerGrain = exactDiv(grainBytes, inputElemSize); + auto const outVec = convert( + reinterpret_cast const&>(data)); + uint32_t dstHeadIdx = idxHead; +#ifdef SPEC_Q_SEQ_LEN + if constexpr (dstIsStrided) { + uint32_t const idxToken = idxHead / headGrpSize; + if (idxToken < SPEC_Q_SEQ_LEN) { + uint32_t const strideBetweenTokens = nbKHeads * headGrpSize; + dstHeadIdx = idxToken * strideBetweenTokens + (idxHead % headGrpSize); + } + } +#endif + reinterpret_cast, nbGrainsPerHead>&>( + dst[dstHeadIdx])[idxGrain] = outVec; + } + } +} + +template +__device__ inline void finalizeAndWriteOut_sync( + uint32_t threadRank, uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& 
swizzleBuf,
+    Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar, ShmQWiseVec const& accColSum,
+    ShmQWiseVec const& accColMax, ShmQWiseVec const* attentionSinksVec, uint32_t nbKHeads) {
+  // @fixme: if ctaNbQHeads is large, use loadShmColWiseVecNoDup + rcp + shfl to avoid 8x waste
+  // of mufu.rcp:
+  // static_assert(ctaNbQHeads <= 8, "Warning: consider using loadShmColWiseVecNoDup + rcp +
+  // shfl to avoid 8x waste of mufu.rcp");
+  auto regColSum = loadShmColWiseVecWithDup(accColSum);
+  if (attentionSinksVec != nullptr) {
+    auto const regAccColMax = loadShmColWiseVecWithDup(accColMax);
+    auto const regAttentionSinks = loadGmemColWiseVecWithDup(attentionSinksVec[0], headGrpSize - 1);
+    auto regColSinks = expf(regAttentionSinks - regAccColMax);
+    regColSum = regColSum + regColSinks;
+  }
+  auto const regOutScale = __frcp_rn(regColSum) * xvoScale;
+  rescaleAcc(acc, regOutScale);
+
+  saveTransposedOutput(threadRank, warpRank, dst, swizzleBuf, acc,
+                       warpGrpBar, nbKHeads);
+  warpGrpBar.arrive_and_wait();
+}
+#else
+template
+__device__ inline void finalizeAndWriteOut_sync(
+    uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc,
+    float xvoScale, ShmQWiseVec const& accRowSum,
+    uint32_t nbKHeads /* for spec dec; set to 1 for workspace */, uint32_t ctaNbValidTokens) {
+  auto const regRowSum = loadShmRowWiseVecWithDup(warpRank, accRowSum);
+  auto const regOutScale = __frcp_rn(regRowSum) * xvoScale;
+  rescaleAcc(acc, regOutScale);
+
+  using DstElem = typename DstHead::Elem;
+  auto const lane = laneId();
+  uint32_t const idxQuad = lane / 4;
+  uint32_t const idxInQuad = lane % 4;
+  using Atom = Vec, 4>;
+  using SwizzleBuf = Array2D, 4>, ctaNbQHeads, exactDiv(headElems, 4 * 4)>;
+  static_assert(sizeof(SwizzleBuf) <= sizeof(swizzleBuf));
+  auto& buf = reinterpret_cast(swizzleBuf);
+#pragma unroll
+  for (uint32_t m = 0; m < Gemm1Acc::rows; m++) {
+#pragma unroll
+    for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) {
+      uint32_t const r = gmma::instM * m + 16 * warpRank + 8 * i + idxQuad;
+      static_assert(SwizzleBuf::cols == exactDiv(Gemm1Acc::cols, 2));
+#pragma unroll
+      for (uint32_t n = 0; n < exactDiv(Gemm1Acc::cols, 2); n++) {
+        Vec const v =
+            convert(Vec{acc(m, n * 2)(i, 0), acc(m, n * 2 + 1)(i, 0),
+                        acc(m, n * 2)(i, 1), acc(m, n * 2 + 1)(i, 1)});
+        // @fixme: without the reinterpret_cast to V, the compiler generates wrong code and
+        // requires a __syncwarp() after rescaleAcc() as a workaround. Likely a compiler bug.
+        // @todo: report a compiler bug.
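+        /* Illustrative aside (not part of this patch): the epilogue in both
+           finalizeAndWriteOut_sync variants is the usual streaming-softmax normalization.
+           A scalar sketch with illustrative names (sink is the optional attention-sink logit):
+             float denom = rowSum;                       // sum of exp(x - rowMax) over all keys
+             if (hasSink) denom += expf(sink - rowMax);  // matches the SWAP_AB branch above
+             float const outScale = __frcp_rn(denom) * xvoScale;
+             acc *= outScale;                            // applied by rescaleAcc()
+        */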
+ using V = Vec; + reinterpret_cast(buf.template at(r, n)[idxInQuad]) = + reinterpret_cast(v); + // buf.template at(r, n)[idxInQuad] = v; + } + } + } + __syncwarp(); + +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { + constexpr uint32_t srcHeadBytes = sizeof(DstElem) * headElems; + constexpr uint32_t grpSize = exactDiv(srcHeadBytes, grainBytes); + constexpr uint32_t nbGrps = exactDiv(warp_size, grpSize); + uint32_t const idxGrp = lane / grpSize; + constexpr uint32_t grainsPerAtom = exactDiv(sizeof(Atom), grainBytes); + uint32_t const rowBase = gmma::instM * m + 16 * warpRank; + constexpr uint32_t totalNbGrains = grainsPerAtom * SwizzleBuf::cols * 16; + uint32_t const nbIters = divUp(totalNbGrains, nbGrps); + constexpr bool wholeIters = (totalNbGrains % nbGrps == 0); + constexpr bool wholeHeads = (validElemsPerHead == headElems); +#pragma unroll + for (uint32_t iter = 0; iter < nbIters; iter++) { + uint32_t const idxGrain = nbGrps * iter + idxGrp; + constexpr uint32_t grainsPerSrcHead = exactDiv(srcHeadBytes, grainBytes); + uint32_t const r = idxGrain / grainsPerSrcHead; + if (!wholeIters && r >= 16) { + break; + } + uint32_t const cGrain = idxGrain % grainsPerSrcHead; + uint32_t const cAtom = cGrain / grainsPerAtom; + constexpr uint32_t grainsPerDstHead = exactDiv(sizeof(DstHead), grainBytes); + uint32_t const glbRow = gmma::instM * m + 16 * warpRank + r; + if (ctaNbValidQHeads != ctaNbQHeads && glbRow >= ctaNbValidQHeads) { + break; + } + if (wholeHeads || cGrain < grainsPerDstHead) { + uint32_t const srcRow = rowBase + r; + auto const data = reinterpret_cast( + buf.template at(srcRow, cAtom))[cGrain % grainsPerAtom]; +#if SPEC_DEC + static_assert(beamWidth == 1); + uint32_t const idxToken = srcRow / headGrpSize; // inside CTA + if (idxToken >= ctaNbValidTokens) { + break; + } + uint32_t const tokenPad = headGrpSize * (nbKHeads - 1); + uint32_t const dstRow = srcRow + idxToken * tokenPad; +#else + uint32_t const dstRow = srcRow; +#endif + reinterpret_cast(dst[dstRow])[cGrain] = data; + } + } + } +} +#endif + +template +__device__ inline Vec, ropeNbPairsPerThrd> loadHead( + Vec const& head, uint32_t tid) { + constexpr uint32_t nbPairs = exactDiv(validElemsPerHead, 2); + constexpr uint32_t nbPairsPerThrd = ropeNbPairsPerThrd; + constexpr uint32_t nbWorkingThrds = exactDiv(nbPairs, nbPairsPerThrd); + bool const isWorkingThrd = (nbWorkingThrds == nbThrds || tid < nbWorkingThrds); + static_assert(nbPairs % nbPairsPerThrd == 0); + Vec, nbPairsPerThrd> ret; + if constexpr (forNeox) { + auto const& pairs = + reinterpret_cast, nbWorkingThrds>, 2> const&>(head); + auto const data = isWorkingThrd + ? Vec, 2>{pairs[0][tid], pairs[1][tid]} + : Vec, 2>{}; + Vec, 2> const tmp = {convert(data[0]), + convert(data[1])}; +#pragma unroll + for (uint32_t i = 0; i < nbPairsPerThrd; i++) { + ret[i][0] = tmp[0][i]; + ret[i][1] = tmp[1][i]; + } + } else { + auto const data = + isWorkingThrd ? 
reinterpret_cast, nbPairsPerThrd> const*>(&head)[tid] + : Vec, nbPairsPerThrd>{}; +#pragma unroll + for (uint32_t i = 0; i < nbPairsPerThrd; i++) { + ret[i] = convert(data[i]); + } + } + return ret; +} + +template +__device__ inline mha::conditional_t, 2>, + Vec, nbPairsPerThrd>> +applyRoPE(Vec, nbPairsPerThrd> const& data, + Vec, nbPairsPerThrd> const& ropeCosSin) { + Vec, nbPairsPerThrd> r; +#pragma unroll + for (uint32_t i = 0; i < nbPairsPerThrd; i++) { + float const x = data[i][0]; + float const y = data[i][1]; + float const c = ropeCosSin[i][0]; + float const s = ropeCosSin[i][1]; + r[i] = Vec{c * x - s * y, s * x + c * y}; + } + if constexpr (forNeox) { + Vec, 2> tmp; +#pragma unroll + for (uint32_t i = 0; i < nbPairsPerThrd; i++) { + tmp[0][i] = r[i][0]; + tmp[1][i] = r[i][1]; + } + return Vec, 2>{convert(tmp[0]), + convert(tmp[1])}; + } else { + Vec, nbPairsPerThrd> ret; +#pragma unroll + for (uint32_t i = 0; i < nbPairsPerThrd; i++) { + ret[i] = convert(r[i]); + } + return ret; + } +} + +template +__device__ inline void storeRotatedPairsForKV( + GMemCacheHead& dst, + mha::conditional_t>, 2>, + Vec, ropeNbPairsPerThrd>> const& src, + uint32_t tid) { + constexpr uint32_t nbPairs = exactDiv(validElemsPerHead, 2); + constexpr uint32_t nbPairsPerThrd = ropeNbPairsPerThrd; + constexpr uint32_t nbWorkingThrds = exactDiv(nbPairs, nbPairsPerThrd); + bool const isWorkingThrd = (nbWorkingThrds == nbThrds || tid < nbWorkingThrds); + static_assert(nbPairs % nbPairsPerThrd == 0); + if (!isWorkingThrd) { + return; + } + if constexpr (forNeox) { + auto& pairs = + reinterpret_cast, nbWorkingThrds>, 2>&>(dst); + pairs[0][tid] = src[0]; + pairs[1][tid] = src[1]; + } else { + reinterpret_cast, nbPairsPerThrd>*>(&dst)[tid] = src; + } +} + +template +__device__ inline void storeRotatedPairsForQ( + SharedMem::QBuffer& dst, + mha::conditional_t>, 2>, + Vec, ropeNbPairsPerThrd>> const& src, + uint32_t row, uint32_t tid) { + constexpr uint32_t nbPairs = exactDiv(validElemsPerHead, 2); + constexpr uint32_t nbPairsPerThrd = ropeNbPairsPerThrd; + constexpr uint32_t nbWorkingThrds = exactDiv(nbPairs, nbPairsPerThrd); + bool const isWorkingThrd = (nbWorkingThrds == nbThrds || tid < nbWorkingThrds); + static_assert(nbPairs % nbPairsPerThrd == 0); + if (isWorkingThrd) { + if constexpr (forNeox) { +#pragma unroll + for (uint32_t i = 0; i < 2; i++) { + auto const byteOffset = + BoundedVal{cacheElemSize * nbPairsPerThrd * (nbWorkingThrds * i + tid)}; + uint32_t const idxPart = byteOffset.template divBy().get(); + auto const byteOffsetInsidePart = byteOffset.template mod(); + uint32_t const idxGrain = byteOffsetInsidePart.template divBy().get(); + LdGrain& grain = dst[idxPart].template at(row, idxGrain); + uint32_t const byteOffsetInsideGrain = + byteOffsetInsidePart.template mod().get(); + static_assert(cacheElemSize * nbPairsPerThrd <= grainBytes && + grainBytes % (cacheElemSize * nbPairsPerThrd) == 0); + reinterpret_cast&>( + reinterpret_cast(&grain)[byteOffsetInsideGrain]) = src[i]; + } + } else { + auto const byteOffset = BoundedVal{cacheElemSize * 2 * nbPairsPerThrd * tid}; + uint32_t const idxPart = byteOffset.template divBy().get(); + auto const byteOffsetInsidePart = byteOffset.template mod(); + uint32_t const idxGrain = byteOffsetInsidePart.template divBy().get(); + LdGrain& grain = dst[idxPart].template at(row, idxGrain); + uint32_t const byteOffsetInsideGrain = byteOffsetInsidePart.template mod().get(); + static_assert(cacheElemSize * 2 * nbPairsPerThrd <= grainBytes && + grainBytes % (cacheElemSize * 2 
* nbPairsPerThrd) == 0); + reinterpret_cast, nbPairsPerThrd>&>( + reinterpret_cast(&grain)[byteOffsetInsideGrain]) = src; + } + } + static_assert(validElemsPerHead % 16 == 0); + __syncwarp(); + if constexpr (validElemsPerHead < headElems) { + static_assert(validElemsPerHead >= headElems - exactDiv(headElems, nbQParts)); + constexpr uint32_t nbPadGrainsPerHead = + exactDiv(headElems - validElemsPerHead, cacheElemsPerGrain); + constexpr uint32_t nbPadGrains = nbPadGrainsPerHead * ctaNbQHeads; + uint32_t const nbIters = divUp(nbPadGrains, nbThrds); +#pragma unroll + for (uint32_t iter = 0; iter < nbIters; iter++) { + uint32_t idx = tid + nbThrds * iter; + if (idx >= nbPadGrains) { + break; + } + uint32_t const r = idx / nbPadGrainsPerHead; + uint32_t const c = grainsPerQPart - nbPadGrainsPerHead + idx % nbPadGrainsPerHead; + dst[dst.size - 1].template at(r, c) = LdGrain{}; + } + } +} + +#ifndef GENERATE_CUBIN +void launchHopperF8MHA( + cudaDeviceProp const& prop, uint32_t nbKHeads, +#if SLIDING_WINDOW + uint32_t slidingWinSize, +#endif + float qScale, OutputHead* output, +#if LOW_PREC_OUTPUT + float const* rcpOutScale, +#endif +#if USE_INPUT_KV + InputHead const* qkv, +#if ROPE_STYLE != 0 + Vec const* ropeCosSin, +#endif +#else + InputHead const* q, +#endif + float const* attentionSinks, // [headGrpSize] +#if USE_PAGED_KV_CACHE +#if PAGED_KV_CACHE_LAYOUT == 1 + GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, +#else + GMemCacheHead* pool, // global pool of pages +#endif + KVCachePageIndex const* + kvCachePageList, // device pointer. shape: + // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] +#else + GMemKVCacheHead* kvCacheData, +#endif + uint32_t maxSeqLen, uint32_t const* seqLen, +#if USE_BEAM_SEARCH + BeamSearchParams const& beamSearchParams, +#endif + uint32_t batchSize, + float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. + // Used only for int8/fp8 KV cache. 
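+    /* Illustrative aside (not part of this patch): a minimal sketch of the intended
+       kvCacheScale semantics for an fp8 cache element; the helper name is an assumption,
+       not part of this codebase:
+         __device__ inline float dequantCacheElem(__nv_fp8_e4m3 e, float const* kvCacheScale) {
+           return float(e) * kvCacheScale[0];  // one device-side scalar shared by K and V
+         }
+    */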
+#if SPEC_DEC + SpecDecParams const& specDecParams, +#endif + uint32_t* semaphores, void* scratch, cudaStream_t stream) { + if (beamWidth != 1) { + throw std::runtime_error("not implemented"); + } + static uint32_t const hostSmemSize = [&]() { + uint32_t size; + checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize))); + checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size)); + return size; + }(); + // printf("smemSize = %u\n", hostSmemSize); + uint32_t const nbVHeads = nbKHeads; + uint32_t const nbQHeads = nbKHeads * headGrpSize; + uint32_t const nbQKVHeads = nbQHeads + nbKHeads + nbVHeads; + uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t { + auto const env = std::getenv("XQA_NB_SUB_SEQ"); + if (env != nullptr) { + int32_t const val = std::stoi(env); + if (val > 0) { + return val; + } + } + float const factor = 0.25f; + return mha::min( + mha::max( + 1U, (uint32_t)round(prop.multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)), + divUp(maxSeqLen, gemm0CtaTileNbTokens)); + }(); +#if SPEC_DEC + uint32_t const qSeqLen = specDecParams.qSeqLen; +#else + uint32_t const qSeqLen = 1; +#endif + // gridDim.z == nbKHeads * batchSize && gridDim.y == nbSubSeqPerSeq && gridDim.x == + // nbInputSeqSplit + dim3 const dimGrid{divUp(qSeqLen, inputTokensPerCta), nbSubSeqPerSeq, nbKHeads * batchSize}; + dim3 const dimCta{warp_size * gmmaWarpsPerGrp, 1, 3}; + auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0); +#if USE_PAGED_KV_CACHE + uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); + auto const dtype = [] { + if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } else if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + } else if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } + throw std::runtime_error("unsupported cache element type"); + }(); + +#if PAGED_KV_CACHE_LAYOUT == 1 + KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, + maxNbPagesPerSeq}; + + auto const tensorMapVLLMK = + makeTensorMapForPagedKVCache(kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, + cacheHeadPartElems, gemm0CtaTileNbTokens); + auto const tensorMapVLLMV = + makeTensorMapForPagedKVCache(vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, + cacheHeadPartElems, gemm0CtaTileNbTokens); +#else + KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq}; + auto const tensorMap = + makeTensorMapForPagedKVCache(pool, dtype, validElemsPerHead, nbKHeads, tokensPerPage, + cacheHeadPartElems, gemm0CtaTileNbTokens); +#endif + + cudaError_t const err = cudaLaunchKernelEx(&launchCfg, &kernel_mha, nbKHeads, +#if SLIDING_WINDOW + slidingWinSize, +#endif + qScale, output, +#if LOW_PREC_OUTPUT + rcpOutScale, +#endif +#if USE_INPUT_KV + qkv, +#if ROPE_STYLE != 0 + ropeCosSin, +#endif +#else + q, +#endif + attentionSinks, cacheList, +#if USE_BEAM_SEARCH + beamSearchParams, +#endif + batchSize, kvCacheScale, +#if PAGED_KV_CACHE_LAYOUT == 1 + tensorMapVLLMK, tensorMapVLLMV, +#else + tensorMap, +#endif +#if SPEC_DEC + specDecParams, +#endif + semaphores, scratch); +#else + KVCacheList const cacheList{kvCacheData, seqLen, maxSeqLen}; + static_assert(!usePagedKVCache); + assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens); + auto const tensorMap = makeTensorMapForContiguousKVCache( + kvCacheData, CU_TENSOR_MAP_DATA_TYPE_UINT8, validElemsPerHead, nbKHeads, maxSeqLen, beamWidth, + batchSize, cacheHeadPartElems, 
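+  /* Illustrative aside (not part of this patch): a worked example of the nbSubSeqPerSeq
+     heuristic above, assuming 132 SMs (H100), batchSize = 4, nbKHeads = 8, maxSeqLen = 8192
+     and gemm0CtaTileNbTokens = 64:
+       132 * 3 / (4 * 8) = 12          (integer division)
+       round(12 * 0.25f) = 3
+       divUp(8192, 64) = 128
+       nbSubSeqPerSeq = min(max(1, 3), 128) = 3   // 3 CTAs cooperate on each sequence
+  */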
gemm0CtaTileNbTokens); + cudaError_t const err = + cudaLaunchKernelEx(&launchCfg, kernel_mha, nbKHeads, +#if SLIDING_WINDOW + slidingWinSize, +#endif + qScale, output, +#if LOW_PREC_OUTPUT + rcpOutScale, +#endif +#if USE_INPUT_KV + qkv, +#if ROPE_STYLE != 0 + ropeCosSin, +#endif +#else + q, +#endif + attentionSinks, cacheList, +#if USE_BEAM_SEARCH + beamSearchParams, +#endif + batchSize, kvCacheScale, tensorMap, semaphores, scratch); +#endif + checkCuda(err); +} +#endif + +void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, + uint32_t slidingWinSize, float qScale, OutputHead* output, +#if LOW_PREC_OUTPUT + float const* rcpOutScale, +#endif + InputHead const* q, float const* attentionSinks, + GMemCacheHead* pool, KVCachePageIndex const* kvCachePageList, + uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, + float const* __restrict__ kvCacheScale, +#if SPEC_DEC + uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, +#endif + uint32_t* semaphores, void* scratch, cudaStream_t stream) { + static uint32_t const hostSmemSize = [&]() { + uint32_t size; + checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize))); + checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size)); + return size; + }(); + uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t { + float const factor = 0.25f; + return mha::min( + mha::max( + 1U, (uint32_t)round(multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)), + divUp(maxSeqLen, gemm0CtaTileNbTokens)); + }(); +#if SPEC_DEC + auto specDecParams = SpecDecParams{qSeqLen, qCuSeqLens, mask}; + uint32_t const qLen = qSeqLen; +#else + uint32_t const qLen = 1; +#endif + dim3 const dimGrid{divUp(qLen, inputTokensPerCta), nbSubSeqPerSeq, nbKHeads * batchSize}; + dim3 const dimCta{warp_size * gmmaWarpsPerGrp, 1, 3}; + auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0); +#if USE_PAGED_KV_CACHE + uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); + auto const dtype = [] { + if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } else if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + } else if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } + throw std::runtime_error("unsupported cache element type"); + }(); + +#if PAGED_KV_CACHE_LAYOUT == 1 + KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, + maxNbPagesPerSeq}; + + auto const tensorMapVLLMK = + makeTensorMapForPagedKVCache(kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, + cacheHeadPartElems, gemm0CtaTileNbTokens); + auto const tensorMapVLLMV = + makeTensorMapForPagedKVCache(vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, + cacheHeadPartElems, gemm0CtaTileNbTokens); +#else + KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq}; + auto const tensorMap = + makeTensorMapForPagedKVCache(pool, dtype, validElemsPerHead, nbKHeads, tokensPerPage, + cacheHeadPartElems, gemm0CtaTileNbTokens); +#endif + + cudaError_t const err = cudaLaunchKernelEx(&launchCfg, &kernel_mha, nbKHeads, +#if SLIDING_WINDOW + slidingWinSize, +#endif + qScale, output, +#if LOW_PREC_OUTPUT + rcpOutScale, +#endif + q, attentionSinks, cacheList, batchSize, kvCacheScale, +#if PAGED_KV_CACHE_LAYOUT == 1 + tensorMapVLLMK, tensorMapVLLMV, +#else + tensorMap, +#endif +#if SPEC_DEC + specDecParams, +#endif + semaphores, scratch); +#else + KVCacheList const cacheList{kvCacheData, seqLen, 
maxSeqLen}; + static_assert(!usePagedKVCache); + assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens); + auto const tensorMap = makeTensorMapForContiguousKVCache( + kvCacheData, CU_TENSOR_MAP_DATA_TYPE_UINT8, validElemsPerHead, nbKHeads, maxSeqLen, beamWidth, + batchSize, cacheHeadPartElems, gemm0CtaTileNbTokens); + cudaError_t const err = cudaLaunchKernelEx(&launchCfg, kernel_mha, nbKHeads, +#if SLIDING_WINDOW + slidingWinSize, +#endif + qScale, output, +#if LOW_PREC_OUTPUT + rcpOutScale, +#endif + q, attentionSinks, cacheList, batchSize, kvCacheScale, + tensorMap, semaphores, scratch); +#endif + checkCuda(err); +} +#endif diff --git a/csrc/xqa/tensorMap.cpp b/csrc/xqa/tensorMap.cpp new file mode 100644 index 0000000000..e79272b018 --- /dev/null +++ b/csrc/xqa/tensorMap.cpp @@ -0,0 +1,117 @@ +#include "tensorMap.h" + +#include +#include + +#include + +#include "utils.h" + +uint32_t getElemBytes(CUtensorMapDataType_enum dataType) { + switch (dataType) { + case CU_TENSOR_MAP_DATA_TYPE_UINT8: + return 1; + case CU_TENSOR_MAP_DATA_TYPE_UINT16: + return 2; + case CU_TENSOR_MAP_DATA_TYPE_UINT32: + return 4; + case CU_TENSOR_MAP_DATA_TYPE_INT32: + return 4; + case CU_TENSOR_MAP_DATA_TYPE_UINT64: + return 8; + case CU_TENSOR_MAP_DATA_TYPE_INT64: + return 8; + case CU_TENSOR_MAP_DATA_TYPE_FLOAT16: + return 2; + case CU_TENSOR_MAP_DATA_TYPE_FLOAT32: + return 4; + case CU_TENSOR_MAP_DATA_TYPE_FLOAT64: + return 8; + case CU_TENSOR_MAP_DATA_TYPE_BFLOAT16: + return 2; + case CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ: + return 4; + case CU_TENSOR_MAP_DATA_TYPE_TFLOAT32: + return 4; + case CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ: + return 4; + default: + throw std::runtime_error("unsupported data type"); + } +} + +CUtensorMap makeTensorMapForContiguousKVCache(void const* addr, CUtensorMapDataType_enum dataType, + uint32_t headElems, uint32_t nbKHeads, + uint32_t maxCacheLen, uint32_t beamWidth, + uint32_t batchSize, uint32_t partElems, + uint32_t nbTokens) { + CUtensorMap tensorMap{}; + uint64_t const globalDims[] = {headElems, maxCacheLen, nbKHeads, 2 * beamWidth * batchSize}; + uint32_t elemBytes = getElemBytes(dataType); + uint32_t const headBytes = elemBytes * headElems; + uint64_t const globalStrides[] = {headBytes, headBytes * maxCacheLen, + headBytes * maxCacheLen * nbKHeads}; + uint32_t const boxDims[] = {partElems, nbTokens, 1, 1}; + uint32_t const elemStrides[] = {1, 1, 1, 1}; + + auto const swizzle = [&] { + switch (partElems) { + case 128: + return CU_TENSOR_MAP_SWIZZLE_128B; + case 64: + return CU_TENSOR_MAP_SWIZZLE_64B; + default: + throw std::runtime_error("unsupported cache head size"); + } + }(); + + checkCu(cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast(addr), globalDims, + globalStrides, boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + return tensorMap; +} + +CUtensorMap makeTensorMapForPagedKVCache(void const* addr, CUtensorMapDataType_enum dataType, + uint32_t headElems, uint32_t nbKHeads, + uint32_t tokensPerPage, uint32_t partElems, + uint32_t nbTokensPerTile) { + CUtensorMap tensorMap{}; + uint32_t elemBytes = getElemBytes(dataType); +// VLLM Layout +#if PAGED_KV_CACHE_LAYOUT == 1 + uint64_t const globalDims[] = {headElems, nbKHeads, tokensPerPage, 1U << 31}; + uint32_t const headBytes = elemBytes * headElems; + uint64_t const globalStrides[] = {headBytes, headBytes * nbKHeads, + headBytes * nbKHeads * tokensPerPage}; + uint32_t const partBytes = partElems * elemBytes; + uint32_t const 
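+  /* Illustrative aside (not part of this patch): each TMA box below covers one head part for
+     up to nbTokensPerTile tokens of a single page, and the swizzle is selected from
+     partBytes = partElems * elemBytes. Worked example under assumed configurations: an fp8
+     cache (1 B/elem) with partElems = 128 gives partBytes = 128 -> the 128B swizzle; an fp16
+     cache (2 B/elem) with partElems = 64 gives the same 128B swizzle, while partElems = 32
+     would give the 64B swizzle. */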
boxDims[] = {partElems, 1, mha::min(tokensPerPage, nbTokensPerTile), 1}; + uint32_t const elemStrides[] = {1, 1, 1, 1}; + // XQA Original Layout +#else + uint64_t const globalDims[] = {headElems, tokensPerPage, nbKHeads, 1U << 31}; + uint32_t const headBytes = elemBytes * headElems; + uint64_t const globalStrides[] = {headBytes, headBytes * tokensPerPage, + headBytes * tokensPerPage * nbKHeads}; + uint32_t const partBytes = partElems * elemBytes; + uint32_t const boxDims[] = {partElems, mha::min(tokensPerPage, nbTokensPerTile), 1, 1}; + uint32_t const elemStrides[] = {1, 1, 1, 1}; +#endif + + auto const swizzle = [&] { + switch (partBytes) { + case 128: + return CU_TENSOR_MAP_SWIZZLE_128B; + case 64: + return CU_TENSOR_MAP_SWIZZLE_64B; + default: + throw std::runtime_error("unsupported cache head size"); + } + }(); + + checkCu(cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast(addr), globalDims, + globalStrides, boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + return tensorMap; +} diff --git a/csrc/xqa/tensorMap.h b/csrc/xqa/tensorMap.h new file mode 100644 index 0000000000..d0b2c76b96 --- /dev/null +++ b/csrc/xqa/tensorMap.h @@ -0,0 +1,16 @@ +#pragma once + +#include + +uint32_t getElemBytes(CUtensorMapDataType_enum dataType); + +CUtensorMap makeTensorMapForContiguousKVCache(void const* addr, CUtensorMapDataType_enum dataType, + uint32_t headElems, uint32_t nbKHeads, + uint32_t maxCacheLen, uint32_t beamWidth, + uint32_t batchSize, uint32_t partElems, + uint32_t nbTokens); + +CUtensorMap makeTensorMapForPagedKVCache(void const* addr, CUtensorMapDataType_enum dataType, + uint32_t headElems, uint32_t nbKHeads, + uint32_t tokensPerPage, uint32_t partElems, + uint32_t nbTokensPerTile); diff --git a/csrc/xqa/tma.h b/csrc/xqa/tma.h new file mode 100644 index 0000000000..5cf67238a2 --- /dev/null +++ b/csrc/xqa/tma.h @@ -0,0 +1,302 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */
+
+#pragma once
+
+#include "cuda_hint.cuh"
+#include "utils.h"
+#ifndef GENERATE_CUBIN
+#include 
+#include 
+
+#include 
+#endif
+#include "barriers.cuh"
+
+enum class StateSpace { kCONSTANT, kPARAMETER, kGENERIC };
+
+#ifdef GENERATE_CUBIN
+#define CU_TENSOR_MAP_NUM_QWORDS 16
+
+typedef struct CUtensorMap_st {
+#if defined(__cplusplus) && (__cplusplus >= 201103L)
+  alignas(64)
+#elif __STDC_VERSION__ >= 201112L
+  _Alignas(64)
+#endif
+      uint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS];
+} CUtensorMap;
+#endif
+
+namespace tma {
+
+__device__ inline void loadLinearAsync(void* dst, void const* src, uint32_t nbBytes,
+                                       CtaBarrier& bar) {
+  asm volatile(
+      "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n"
+      :
+      : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(src)), "r"(nbBytes),
+        "l"(__cvta_generic_to_shared(&bar))
+      : "memory");
+}
+
+__device__ inline void prefetchLinear(void const* src, uint32_t nbBytes) {
+  asm volatile(
+      "cp.async.bulk.prefetch.L2.global [%0], %1;\n" ::"l"(reinterpret_cast(src)),
+      "r"(nbBytes)
+      : "memory");
+}
+
+// dst and &bar must be remote addresses generated by mapa, and src must be a local address.
+__device__ inline void sm2smCopyAsync(void* dst, void const* src, uint32_t nbBytes,
+                                      CgaBarrier& bar) {
+  asm volatile(
+      "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, "
+      "[%3];\n"
+      :
+      : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(src)), "r"(nbBytes),
+        "l"(__cvta_generic_to_shared(&bar))
+      : "memory");
+}
+
+template
+__device__ inline void loadAsync(void* dst, CUtensorMap const& tensorMap, DimsLE offset,
+                                 CtaBarrier& bar) {
+  if constexpr (nbDims == 1) {
+    // nbDims==1 does not need a tensormap and should just use cp.async.bulk
+    asm volatile(
+        "cp.async.bulk.tensor.1d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, "
+        "{%2}], [%3];\n"
+        :
+        : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)),
+          "r"(offset[0]), "l"(__cvta_generic_to_shared(&bar))
+        : "memory");
+  } else if constexpr (nbDims == 2) {
+    asm volatile(
+        "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, "
+        "{%2, %3}], [%4];\n"
+        :
+        : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)),
+          "r"(offset[0]), "r"(offset[1]), "l"(__cvta_generic_to_shared(&bar))
+        : "memory");
+  } else if constexpr (nbDims == 3) {
+    asm volatile(
+        "cp.async.bulk.tensor.3d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, "
+        "{%2, %3, %4}], "
+        "[%5];\n"
+        :
+        : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)),
+          "r"(offset[0]), "r"(offset[1]), "r"(offset[2]), "l"(__cvta_generic_to_shared(&bar))
+        : "memory");
+  } else if constexpr (nbDims == 4) {
+    asm volatile(
+        "cp.async.bulk.tensor.4d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, "
+        "{%2, %3, %4, "
+        "%5}], [%6];\n"
+        :
+        : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)),
+          "r"(offset[0]), "r"(offset[1]), "r"(offset[2]), "r"(offset[3]),
+          "l"(__cvta_generic_to_shared(&bar))
+        : "memory");
+  } else if constexpr (nbDims == 5) {
+    asm volatile(
+        "cp.async.bulk.tensor.5d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, "
+        "{%2, %3, %4, %5, "
+        "%6}], [%7];\n"
+        :
+        : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)),
+          "r"(offset[0]), "r"(offset[1]), "r"(offset[2]), "r"(offset[3]), "r"(offset[4]),
+          "l"(__cvta_generic_to_shared(&bar))
+        : "memory");
+  } else {
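+    /* Illustrative aside (not part of this patch): a typical single-producer call pattern for
+       the tile loads above, assuming an mbarrier wrapper exposing transaction accounting; the
+       names bar.arrive_tx, bar.wait, smemTile, x, y are assumptions, not this codebase's API:
+         if (elected) {
+           bar.arrive_tx(1, tileBytes);  // expected-transaction-bytes bookkeeping
+           tma::loadAsync(smemTile, tensorMap, DimsLE<2>{x, y}, bar);
+         }
+         bar.wait(parity);               // completes once all tileBytes have arrived
+    */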
+    static_assert(nbDims >= 1 && nbDims <= 5);
+  }
+}
+
+template
+__device__ inline void loadAsync(void* dst, CUtensorMap const& tensorMap, DimsLE offset,
+                                 CtaBarrier& bar, uint64_t cacheHint) {
+  if constexpr (nbDims == 1) {
+    // nbDims==1 does not need a tensormap and should just use cp.async.bulk
+    asm volatile(
+        "cp.async.bulk.tensor.1d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_"
+        "hint [%0], [%1, "
+        "{%2}], [%3], %4;\n"
+        :
+        : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)),
+          "r"(offset[0]), "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint)
+        : "memory");
+  } else if constexpr (nbDims == 2) {
+    asm volatile(
+        "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_"
+        "hint [%0], [%1, "
+        "{%2, %3}], [%4], %5;\n"
+        :
+        : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)),
+          "r"(offset[0]), "r"(offset[1]), "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint)
+        : "memory");
+  } else if constexpr (nbDims == 3) {
+    asm volatile(
+        "cp.async.bulk.tensor.3d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_"
+        "hint [%0], [%1, "
+        "{%2, %3, %4}], [%5], %6;\n"
+        :
+        : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)),
+          "r"(offset[0]), "r"(offset[1]), "r"(offset[2]), "l"(__cvta_generic_to_shared(&bar)),
+          "l"(cacheHint)
+        : "memory");
+  } else if constexpr (nbDims == 4) {
+    asm volatile(
+        "cp.async.bulk.tensor.4d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_"
+        "hint [%0], [%1, "
+        "{%2, %3, %4, %5}], [%6], %7;\n"
+        :
+        : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)),
+          "r"(offset[0]), "r"(offset[1]), "r"(offset[2]), "r"(offset[3]),
+          "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint)
+        : "memory");
+  } else if constexpr (nbDims == 5) {
+    asm volatile(
+        "cp.async.bulk.tensor.5d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_"
+        "hint [%0], [%1, "
+        "{%2, %3, %4, %5, %6}], [%7], %8;\n"
+        :
+        : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)),
+          "r"(offset[0]), "r"(offset[1]), "r"(offset[2]), "r"(offset[3]), "r"(offset[4]),
+          "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint)
+        : "memory");
+  } else {
+    static_assert(nbDims >= 1 && nbDims <= 5);
+  }
+}
+
+// shared::cta -> global
+__device__ inline void store1DAsync(void* dst, void const* src, uint32_t nbBytes) {
+  asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;\n"
+               :
+               : "l"(reinterpret_cast(dst)), "l"(__cvta_generic_to_shared(src)),
+                 "r"(nbBytes));
+}
+
+template
+__device__ inline void storeAsync(CUtensorMap const& tensorMap, DimsLE const& offset,
+                                  void* src) {
+  if constexpr (nbDims == 1) {
+    // nbDims==1 does not need a tensormap and should just use cp.async.bulk
+    asm volatile("cp.async.bulk.tensor.1d.global.shared::cta.bulk_group.tile [%0, {%1}], [%2];\n"
+                 :
+                 : "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]),
+                   "l"(__cvta_generic_to_shared(src))
+                 : "memory");
+  } else if constexpr (nbDims == 2) {
+    asm volatile(
+        "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.tile [%0, {%1, %2}], [%3];\n"
+        :
+        : "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
+          "l"(__cvta_generic_to_shared(src))
+        : "memory");
+  } else if constexpr (nbDims == 3) {
+    asm volatile(
+        "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3}], [%4];\n"
+        :
+        : "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
+          "r"(offset[2]), "l"(__cvta_generic_to_shared(src))
+        : "memory");
+  } else if constexpr (nbDims == 4) {
+    asm volatile(
+        "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4}], [%5];\n"
+        :
+        : "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
+          "r"(offset[2]), "r"(offset[3]), "l"(__cvta_generic_to_shared(src))
+        : "memory");
+  } else if constexpr (nbDims == 5) {
+    asm volatile(
+        "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4, %5}], "
+        "[%6];\n"
+        :
+        : "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
+          "r"(offset[2]), "r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(src))
+        : "memory");
+  } else {
+    static_assert(nbDims >= 1 && nbDims <= 5);
+  }
+}
+
+__device__ inline void setTensorMapGlbAddr(CUtensorMap& tensorMap, void* ptr) {
+  asm volatile(
+      "tensormap.replace.tile.global_address.global.b1024.b64 [%0], %1;\n" ::"l"(&tensorMap),
+      "l"(ptr)
+      : "memory");
+}
+
+__device__ inline void commitGroup() {
+  asm volatile("cp.async.bulk.commit_group;\n" : : : "memory");
+}
+
+// wait until only targetNbInFlightGroups groups are still in-flight.
+template
+__device__ inline void waitGroup() {
+  asm volatile("cp.async.bulk.wait_group %0;\n" ::"n"(targetNbInFlightGroups) : "memory");
+}
+
+__device__ inline void prefetchTensorMap(CUtensorMap const& tensorMap,
+                                         StateSpace loc = StateSpace::kGENERIC) {
+  assert(reinterpret_cast(&tensorMap) % alignof(CUtensorMap) == 0);
+  switch (loc) {
+    case StateSpace::kCONSTANT:
+      asm volatile("prefetch.const.tensormap [%0];\n" ::"l"(__cvta_generic_to_constant(&tensorMap))
+                   : "memory");
+      break;
+    case StateSpace::kPARAMETER:
+      asm volatile(
+          "prefetch.param.tensormap [%0];\n" ::"l"(__cvta_generic_to_grid_constant(&tensorMap))
+          : "memory");
+      break;
+    case StateSpace::kGENERIC:
+      asm volatile("prefetch.tensormap [%0];\n" ::"l"(reinterpret_cast(&tensorMap))
+                   : "memory");
+      break;
+    default:
+      asm volatile("trap;\n");
+  }
+}
+
+template
+__device__ inline void storeAsync(void* dst, T const& src, CgaBarrier& bar) {
+  constexpr uint32_t nbWords = exactDiv(sizeof(T), sizeof(uint32_t));
+  Vec const& srcVec = reinterpret_cast const&>(src);
+  if constexpr (nbWords == 1) {
+    asm volatile(
+        "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.u32 [%0], %1, [%2];\n" ::"l"(
+            __cvta_generic_to_shared(dst)),
+        "r"(srcVec[0]), "l"(__cvta_generic_to_shared(&bar))
+        : "memory");
+  } else if constexpr (nbWords == 2) {
+    asm volatile(
+        "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.u32 [%0], {%1, %2}, "
+        "[%3];\n" ::"l"(__cvta_generic_to_shared(dst)),
+        "r"(srcVec[0]), "r"(srcVec[1]), "l"(__cvta_generic_to_shared(&bar))
+        : "memory");
+  } else if constexpr (nbWords == 4) {
+    asm volatile(
+        "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v4.u32 [%0], {%1, %2, %3, %4}, "
+        "[%5];\n" ::"l"(__cvta_generic_to_shared(dst)),
+        "r"(srcVec[0]), "r"(srcVec[1]), "r"(srcVec[2]), "r"(srcVec[3]),
+        "l"(__cvta_generic_to_shared(&bar))
+        : "memory");
+  } else {
+    static_assert(nbWords == 1 || nbWords == 2 || nbWords == 4,
+                  "src size must be 4, 8 or 16 bytes");
+  }
+}
+
+}  // namespace tma
diff --git a/csrc/xqa/utils.cuh b/csrc/xqa/utils.cuh
index 5883e5b834..2804d2b322 100644
--- a/csrc/xqa/utils.cuh
+++ b/csrc/xqa/utils.cuh
@@ -31,7 +31,13 @@
 #include "barriers.cuh"
 inline constexpr float log2e = 1.4426950408889634;  // std::log2(M_E)
-inline constexpr float safeInitRowMax = -1e+30F;
+// we use an optimization where exp(x-rowMax) is computed as:
+/* bias = rowMax * log2e // shared for the
whole row + exp(x-rowMax) = exp2f(x * log2e - bias) +*/ +// But this optimization is not numerically stable when (x * log2e - bias) is computed with FMA and +// x is too large. For this reason, don't set safeInitRowMax with a huge absolute value. +inline constexpr float safeInitRowMax = -1e+5F; inline constexpr int32_t kBAD_PAGE_INDEX = -1; __constant__ constexpr float kE4M3_MAX = 448.F; diff --git a/csrc/xqa/xqa_wrapper.cu b/csrc/xqa/xqa_wrapper.cu index a7bbbfaf0c..2be90d9d23 100644 --- a/csrc/xqa/xqa_wrapper.cu +++ b/csrc/xqa/xqa_wrapper.cu @@ -17,8 +17,8 @@ #include "../pytorch_extension_utils.h" #include "mha.h" -void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingWinSize, - double qScale, at::Tensor output, +void xqa_wrapper(bool run_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads, + int64_t slidingWinSize, double qScale, at::Tensor output, #if LOW_PREC_OUTPUT at::Tensor rcpOutScale, #endif @@ -33,21 +33,22 @@ void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingW float const* attentionSinksPtr = attentionSinks.defined() ? reinterpret_cast(attentionSinks.data_ptr()) : nullptr; + auto const mha_func = run_fp8_mha ? &launchHopperF8MHAFlashInfer : &launchMHAFlashInfer; - launchMHAFlashInfer(multiProcessorCount, nbKHeads, slidingWinSize, qScale, - reinterpret_cast(output.data_ptr()), + mha_func(multiProcessorCount, nbKHeads, slidingWinSize, qScale, + reinterpret_cast(output.data_ptr()), #if LOW_PREC_OUTPUT - reinterpret_cast(rcpOutScale.data_ptr()), + reinterpret_cast(rcpOutScale.data_ptr()), #endif - reinterpret_cast(q.data_ptr()), attentionSinksPtr, - reinterpret_cast(pool.data_ptr()), - reinterpret_cast(kvCachePageList.data_ptr()), - maxSeqLen, reinterpret_cast(seqLen.data_ptr()), batchSize, - reinterpret_cast(kvCacheScale.data_ptr()), + reinterpret_cast(q.data_ptr()), attentionSinksPtr, + reinterpret_cast(pool.data_ptr()), + reinterpret_cast(kvCachePageList.data_ptr()), maxSeqLen, + reinterpret_cast(seqLen.data_ptr()), batchSize, + reinterpret_cast(kvCacheScale.data_ptr()), #if SPEC_DEC - qSeqLen, reinterpret_cast(qCuSeqLens.data_ptr()), - reinterpret_cast(mask.data_ptr()), + qSeqLen, reinterpret_cast(qCuSeqLens.data_ptr()), + reinterpret_cast(mask.data_ptr()), #endif - reinterpret_cast(semaphores.data_ptr()), - reinterpret_cast(scratch.data_ptr()), stream); + reinterpret_cast(semaphores.data_ptr()), + reinterpret_cast(scratch.data_ptr()), stream); } diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 54eb4f6c3d..630e9eee67 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -353,7 +353,8 @@ def gen_attention( def gen_xqa( - use_fp16_: List[bool], + fp16_input_: List[bool], + fp8_kv_cache_: List[bool], token_per_page_: List[int], head_size_: List[int], head_grp_size_: List[int], @@ -365,13 +366,15 @@ def gen_xqa( return # XQA requires SM90+ for ( - use_fp16, + fp16_input, + fp8_kv_cache, token_per_page, head_size, head_grp_size, use_sliding_window, ) in product( - use_fp16_, + fp16_input_, + fp8_kv_cache_, token_per_page_, head_size_, head_grp_size_, @@ -384,7 +387,8 @@ def gen_xqa( continue yield gen_xqa_module( - use_fp16=use_fp16, + fp16_input=fp16_input, + fp8_kv_cache=fp8_kv_cache, token_per_page=token_per_page, head_size=head_size, head_grp_size=head_grp_size, @@ -491,14 +495,16 @@ def gen_all_modules( if add_xqa: # Define XQA configurations to iterate over - xqa_use_fp16_ = [True, False] # fp16 and bf16 + xqa_fp16_input_ = [True, False] # fp16 and bf16 + xqa_fp8_kv_cache_ = [True, False] xqa_token_per_page_ 
= [16, 32, 64, 128] xqa_head_size_ = [64, 128, 256] xqa_head_grp_size_ = [1, 2, 4, 8] # Different group sizes for MQA/GQA jit_specs += list( gen_xqa( - xqa_use_fp16_, + xqa_fp16_input_, + xqa_fp8_kv_cache_, xqa_token_per_page_, xqa_head_size_, xqa_head_grp_size_, diff --git a/flashinfer/xqa.py b/flashinfer/xqa.py index fc34ef6c0b..9581a0f6b9 100644 --- a/flashinfer/xqa.py +++ b/flashinfer/xqa.py @@ -31,7 +31,7 @@ xqa_nvcc_flags = [ "-DNDEBUG=1", "-DBEAM_WIDTH=1", - "-DCACHE_ELEM_ENUM=0", + "-DUSE_INPUT_KV=0", "-DUSE_CUSTOM_BARRIER=1", "-DLOW_PREC_OUTPUT=0", "-DSPEC_DEC=0", @@ -39,16 +39,22 @@ def gen_xqa_module( - use_fp16: bool, + fp16_input: bool, + fp8_kv_cache: bool, token_per_page: int, head_size: int, head_grp_size: int, use_sliding_window: bool, ) -> JitSpec: - if use_fp16: - flag_use_fp16 = ["-DINPUT_FP16=1", "-DDTYPE=__half"] + if fp16_input: + flag_data_type = ["-DINPUT_FP16=1", "-DDTYPE=__half"] else: - flag_use_fp16 = ["-DINPUT_FP16=0", "-DDTYPE=__nv_bfloat16"] + flag_data_type = ["-DINPUT_FP16=0", "-DDTYPE=__nv_bfloat16"] + + if fp8_kv_cache: + flag_data_type.append("-DCACHE_ELEM_ENUM=2") + else: + flag_data_type.append("-DCACHE_ELEM_ENUM=0") if token_per_page not in [16, 32, 64, 128]: raise ValueError( @@ -70,9 +76,11 @@ def gen_xqa_module( flag_sliding_window = ["-DSLIDING_WINDOW=0"] return gen_jit_spec( - f"xqa_use_fp16_{use_fp16}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}", + f"xqa_fp16_input_{fp16_input}_fp8_kv_cache_{fp8_kv_cache}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}", [ jit_env.FLASHINFER_CSRC_DIR / "xqa/mha.cu", + jit_env.FLASHINFER_CSRC_DIR / "xqa/mha_sm90.cu", + jit_env.FLASHINFER_CSRC_DIR / "xqa/tensorMap.cpp", jit_env.FLASHINFER_CSRC_DIR / "xqa/xqa_wrapper.cu", jit_env.FLASHINFER_CSRC_DIR / "flashinfer_xqa_ops.cu", ], @@ -80,29 +88,37 @@ def gen_xqa_module( + sm90a_nvcc_flags + flag_tokens_per_page + flag_head_size - + flag_use_fp16 + + flag_data_type + flag_head_grp_size + flag_sliding_window, + extra_ldflags=["-lcuda"], # Add CUDA Driver API library ) @functools.cache def get_xqa_module( - use_fp16: bool, + fp16_input: bool, + fp8_kv_cache: bool, token_per_page: int, head_size: int, head_grp_size: int, use_sliding_window: bool, ): module = gen_xqa_module( - use_fp16, token_per_page, head_size, head_grp_size, use_sliding_window + fp16_input, + fp8_kv_cache, + token_per_page, + head_size, + head_grp_size, + use_sliding_window, ).build_and_load() @register_custom_op( - f"flashinfer::xqa_use_fp16_{use_fp16}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}", + f"flashinfer::xqa_fp16_input_{fp16_input}_fp8_kv_cache_{fp8_kv_cache}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}", mutates_args=("output", "scratch"), ) def xqa( + run_fp8_mha: bool, multiProcessorCount: int, nbKHeads: int, slidingWinSize: int, @@ -120,6 +136,7 @@ def xqa( scratch: torch.Tensor, ) -> None: module.xqa_wrapper.default( + run_fp8_mha, multiProcessorCount, nbKHeads, slidingWinSize, @@ -138,9 +155,10 @@ def xqa( ) @register_fake_op( - f"flashinfer::xqa_use_fp16_{use_fp16}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}" + 
f"flashinfer::xqa_fp16_input_{fp16_input}_fp8_kv_cache_{fp8_kv_cache}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}" ) def _fake_xqa( + run_fp8_mha: bool, multiProcessorCount: int, nbKHeads: int, slidingWinSize: int, @@ -165,7 +183,9 @@ def _fake_xqa( def xqa( - use_fp16: bool, + fp16_input: bool, + fp8_kv_cache: bool, + run_fp8_mha: bool, token_per_page: int, head_size: int, head_grp_size: int, @@ -189,9 +209,15 @@ def xqa( if get_compute_capability(torch.device(device="cuda"))[0] != 9: raise RuntimeError("XQA is only supported on SM90 GPUs") xqa_module = get_xqa_module( - use_fp16, token_per_page, head_size, head_grp_size, use_sliding_window + fp16_input, + fp8_kv_cache, + token_per_page, + head_size, + head_grp_size, + use_sliding_window, ) xqa_module.xqa( + run_fp8_mha, multiProcessorCount, nbKHeads, sliding_win_size if use_sliding_window else 0, diff --git a/tests/test_xqa.py b/tests/test_xqa.py index 2bdbb9e579..c7f2a58819 100644 --- a/tests/test_xqa.py +++ b/tests/test_xqa.py @@ -153,7 +153,9 @@ def ref_attention( reason="XQA is only supported on SM90 GPUs", ) @pytest.mark.parametrize("use_sliding_window", [True, False]) -@pytest.mark.parametrize("use_fp16", [True, False]) +@pytest.mark.parametrize("fp16_input", [True, False]) +@pytest.mark.parametrize("fp8_kv_cache", [True, False]) +@pytest.mark.parametrize("run_fp8_mha", [True, False]) @pytest.mark.parametrize("use_attention_sinks", [True, False]) @pytest.mark.parametrize("seq_len", [2, 15, 256, 514]) @pytest.mark.parametrize("batch_size", [1, 4]) @@ -166,7 +168,9 @@ def test_xqa( nb_k_heads, seq_len, tokens_per_page, - use_fp16, + fp16_input, + fp8_kv_cache, + run_fp8_mha, valid_elems_per_head, head_grp_size, use_attention_sinks, @@ -185,7 +189,7 @@ def test_xqa( beam_width, nb_q_heads, valid_elems_per_head, - dtype=torch.bfloat16 if not use_fp16 else torch.float16, + dtype=torch.bfloat16 if not fp16_input else torch.float16, device="cuda", ) output.fill_(float("nan")) @@ -194,7 +198,7 @@ def test_xqa( beam_width, nb_q_heads, valid_elems_per_head, - dtype=torch.bfloat16 if not use_fp16 else torch.float16, + dtype=torch.bfloat16 if not fp16_input else torch.float16, device="cuda", ) q_heads.normal_(0, 1) @@ -219,10 +223,14 @@ def test_xqa( cache_heads = torch.zeros( total_nb_cache_heads, valid_elems_per_head, - dtype=torch.bfloat16 if not use_fp16 else torch.float16, + dtype=torch.bfloat16 if not fp16_input else torch.float16, device="cuda", ) cache_heads.normal_(0, 1) + if fp8_kv_cache: + # Scale down the cache heads to keep values within the representable range of FP8 + # and prevent overflow during computation. The factor 4.0 is chosen empirically. 
+        cache_heads /= 4.0
     nb_pages_per_seq = div_up(max_seq_len, tokens_per_page)
     total_nb_pages = nb_pages_per_seq * 2 * beam_width * batch_size
@@ -295,7 +303,9 @@ def cache_head_at(
     scratch_buf = torch.zeros(scratch_size, dtype=torch.uint8, device="cuda")
     xqa(
-        use_fp16,
+        fp16_input,
+        fp8_kv_cache,
+        run_fp8_mha,
         tokens_per_page,
         valid_elems_per_head,
         head_grp_size,
@@ -307,7 +317,7 @@ def cache_head_at(
         output,
         q_heads,
         attention_sinks,
-        cache_heads,
+        cache_heads.to(torch.float8_e4m3fn) if fp8_kv_cache else cache_heads,
         page_list_arg,
         max_seq_len,
         seq_len_list,
@@ -354,4 +364,21 @@ def cache_head_at(
             kernel_output = output[req][b][
                 idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size
             ].to(torch.float32)
-            assert torch.allclose(ref_output, kernel_output, atol=0.01, rtol=0.01)
+            if fp8_kv_cache or run_fp8_mha:
+                atol = 0.05
+                rtol = 0.05
+            else:
+                atol = 0.01
+                rtol = 0.01
+
+            diff_abs = torch.abs(ref_output - kernel_output)
+            diff_rel = diff_abs / (torch.abs(ref_output) + 1e-8)
+
+            within_tolerance = (diff_abs <= atol) | (diff_rel <= rtol)
+
+            pass_ratio = within_tolerance.float().mean().item()
+
+            required_ratio = 0.99
+            assert pass_ratio >= required_ratio, (
+                f"Of {ref_output.numel()} elements, only {pass_ratio:.1%} meet the tolerance "
+                f"criteria; at least {required_ratio:.1%} are required"
+            )
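Illustrative aside (not part of this patch): the relaxed check above accepts an element when it meets either the absolute or the relative tolerance, then requires that at least 99% of elements pass. A minimal scalar sketch of the same rule in C++, with illustrative names:

#include <cmath>
#include <cstddef>

// Fraction of elements with |ref - out| <= atol OR |ref - out| / (|ref| + 1e-8) <= rtol.
inline double passRatio(float const* ref, float const* out, size_t n, float atol, float rtol) {
  size_t nbPass = 0;
  for (size_t i = 0; i < n; ++i) {
    float const diffAbs = std::fabs(ref[i] - out[i]);
    float const diffRel = diffAbs / (std::fabs(ref[i]) + 1e-8f);
    nbPass += (diffAbs <= atol || diffRel <= rtol) ? 1 : 0;
  }
  return static_cast<double>(nbPass) / static_cast<double>(n);
}

// Usage mirroring the test: fp8 paths use 0.05/0.05, other paths 0.01/0.01, e.g.
//   assert(passRatio(ref, out, n, 0.05f, 0.05f) >= 0.99);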