4 changes: 2 additions & 2 deletions csrc/flashinfer_xqa_ops.cu
@@ -16,8 +16,8 @@

#include "pytorch_extension_utils.h"

void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingWinSize,
double qScale, at::Tensor output,
void xqa_wrapper(bool run_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads,
Contributor:
Instead of making this a flag, could we pass a dtype?

Same for the other places where we pass:

  • the type of the input (only bf16 and fp16 supported I think)
  • the type of the kv-cache (fp8 or bf16)
  • the type in which we perform arithmetic (the same type as the kv-cache I think?)
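One hypothetical shape of that suggestion, assuming the three dtypes are passed as at::ScalarType (parameter names here are illustrative, not the actual flashinfer API):

void xqa_wrapper(at::ScalarType inputDtype,    // e.g. at::kHalf or at::kBFloat16
                 at::ScalarType kvCacheDtype,  // e.g. at::kFloat8_e4m3fn or at::kBFloat16
                 at::ScalarType mathDtype,     // presumably matching kvCacheDtype
                 int64_t multiProcessorCount, int64_t nbKHeads,
                 int64_t slidingWinSize, double qScale, at::Tensor output /* ... */);

The wrapper body could then dispatch on the dtype triple instead of accumulating boolean flags.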

int64_t slidingWinSize, double qScale, at::Tensor output,
#if LOW_PREC_OUTPUT
at::Tensor rcpOutScale,
#endif
145 changes: 145 additions & 0 deletions csrc/xqa/gmma.cuh
@@ -0,0 +1,145 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/

#pragma once
#include "cuda_hint.cuh"
#include "mha_stdheaders.cuh"
#include "utils.cuh"
#ifndef __CUDACC__
#include <cuda_runtime.h>
#endif
#include <cuda_fp16.h>
#include <cuda_fp8.h>

namespace gmma {

// Shared-memory swizzle pattern, encoded in the 2-bit swizzle field of the
// wgmma matrix descriptor: no swizzle, or 128/64/32-byte swizzled layouts.
enum class SwizzleMode : uint64_t { kNONE = 0, k128 = 1, k64 = 2, k32 = 3 };

// 64-bit shared-memory matrix descriptor consumed by wgmma.mma_async.
// Addresses and byte offsets are stored in encoded form (see encode()).
struct MatDesc {
uint64_t addr : 16;
uint64_t dimKOffset : 16;
uint64_t dimMNOffset : 16;
uint64_t pad0 : 1;
uint64_t baseOffset : 3;
uint64_t pad1 : 10;
SwizzleMode swizzle : 2;

enum class Raw : uint64_t {};

[[nodiscard]] __device__ inline MatDesc withAddr(void const* data) const {
MatDesc ret = *this;
ret.addr = encode(__cvta_generic_to_shared(data));
return ret;
}

// Descriptor fields hold shared-memory addresses and byte offsets divided by
// 16 (their 16-byte-aligned form), keeping 14 significant bits.
static __device__ inline uint32_t encode(uint32_t val) { return (val & 0x3FFFFU) >> 4; }

__device__ inline bool operator==(MatDesc const& other) const { return raw() == other.raw(); }

__device__ inline Raw const& raw() const {
static_assert(sizeof(MatDesc) == 8);
return reinterpret_cast<Raw const&>(*this);
}

static __device__ inline MatDesc fromRaw(Raw const& raw) {
return reinterpret_cast<MatDesc const&>(raw);
}
};

static_assert(sizeof(MatDesc) == 8);

// Returns `base` with the encoded shared-memory address of `data` added into
// the low 32 bits (the address field) of the raw descriptor.
[[nodiscard]] __device__ inline MatDesc::Raw addAddr(MatDesc::Raw base, void const* data) {
assert((uint32_t(__cvta_generic_to_shared(data)) & ~0x3FFFFU) == 0);
MatDesc::Raw ret = base;
auto& u32x2 = reinterpret_cast<uint32_t(&)[2]>(ret);
u32x2[0] += static_cast<uint32_t>(__cvta_generic_to_shared(data)) >> 4;
return ret;
}

// Builds a descriptor from a tile's shared-memory address, its byte strides
// along gemm-K and gemm-M/N, the start address of the swizzle pattern, and
// the swizzle mode.
__device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset,
uint32_t dimMNByteOffset, void const* patternStartAddr,
SwizzleMode swizzleMode) {
uint32_t const patternAddr = __cvta_generic_to_shared(patternStartAddr);
uint32_t const baseAlign = [&]() -> uint32_t {
switch (swizzleMode) {
case SwizzleMode::kNONE:
return 1;
case SwizzleMode::k128:
return 1024;
case SwizzleMode::k64:
return 512;
case SwizzleMode::k32:
return 256;
}
asm volatile("trap;\n");
return 0;
}();
// If the pattern start is not aligned to the swizzle period, bits 7..9 of its
// address supply the descriptor's 3-bit base offset.
uint32_t const baseOffset = ((patternAddr % baseAlign == 0) ? 0U : ((patternAddr >> 0x7) & 0x7));
return MatDesc{
/*addr=*/MatDesc::encode(__cvta_generic_to_shared(data)),
/*dimKOffset=*/MatDesc::encode(dimKByteOffset),
/*dimMNOffset=*/MatDesc::encode(dimMNByteOffset),
/*pad0=*/0,
/*baseOffset=*/baseOffset,
/*pad1=*/0,
/*swizzle=*/swizzleMode,
};
}

__device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset,
uint32_t dimMNByteOffset, SwizzleMode swizzleMode) {
return makeMatDesc(data, dimKByteOffset, dimMNByteOffset, data, swizzleMode);
}
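// Illustrative usage (tile names hypothetical): build a descriptor for one
// shared-memory tile, then rebind the same layout to another tile:
//   MatDesc const descA = makeMatDesc(tileA, kByteStride, mnByteStride,
//                                     SwizzleMode::k128);
//   MatDesc const descA2 = descA.withAddr(tileA2);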

inline constexpr uint32_t instM = 64;  // wgmma instruction M extent is fixed at 64

// Each wgmma instruction consumes 32 bytes of the math element type along K.
template <typename MathElem>
inline constexpr uint32_t instK = 32 / sizeof(MathElem);

inline constexpr uint32_t instNBase = 8;  // instruction N extent is a multiple of 8

// For both a and b, the outer dim is gemm-K and the inner dim is gemm-M or gemm-N.
// acc is used as both input and output (accHasVal tells whether acc already holds valid data).
template <typename InputElem, uint32_t n, bool transA = false, bool transB = false>
__device__ void mma_async_shmA(float (&acc)[exactDiv(n, instNBase)][2][2], MatDesc::Raw descA,
MatDesc::Raw descB, bool accHasVal);
template <typename InputElem, uint32_t n, bool transA = false, bool transB = false>
__device__ void mma_async_regA(float (&acc)[exactDiv(n, instNBase)][2][2],
uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal);

__device__ inline void fence() { asm volatile("wgmma.fence.sync.aligned;\n"); }

__device__ inline void commit_group() { asm volatile("wgmma.commit_group.sync.aligned;\n"); }

template <uint32_t targetNbInFlightGroups>
__device__ inline void wait_group() {
asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(targetNbInFlightGroups));
}
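// Typical issue sequence for these wrappers (a sketch of the wgmma
// programming model, not code from this file):
//   fence();                    // order prior register/shared-memory accesses
//   mma_async_shmA<half, 64>(acc, descA, descB, /*accHasVal=*/false);
//   commit_group();             // close the batch of issued async MMAs
//   wait_group<0>();            // wait until all batches have completed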

template <bool swizzle, typename T, uint32_t rows, uint32_t cols, bool alignedForSwizzle>
constexpr SwizzleMode getSwizzleMode(Array2D<T, rows, cols, alignedForSwizzle> const&) {
constexpr auto rowBytes = Array2D<T, rows, cols, alignedForSwizzle>::rowBytes;
if constexpr (!swizzle) {
return SwizzleMode::kNONE;
}
if constexpr (rowBytes % 128 == 0) {
return SwizzleMode::k128;
} else if constexpr (rowBytes == 64) {
return SwizzleMode::k64;
} else {
static_assert(rowBytes == 32);
return SwizzleMode::k32;
}
}
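// Illustrative (assuming rowBytes = cols * sizeof(T)): an Array2D<half, 64, 64>
// has rowBytes == 128, so getSwizzleMode<true> on it yields SwizzleMode::k128.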
} // namespace gmma

#include "gmma_impl.cuh"