-
Notifications
You must be signed in to change notification settings - Fork 532
add xqa fp8 mha and fp8 kv cache #1769
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
qsang-nv
wants to merge
3
commits into
flashinfer-ai:main
Choose a base branch
from
qsang-nv:xqa_fp8
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+8,948
β50
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
/* | ||
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement | ||
* | ||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual | ||
* property and proprietary rights in and to this material, related | ||
* documentation and any modifications thereto. Any use, reproduction, | ||
* disclosure or distribution of this material and related documentation | ||
* without an express license agreement from NVIDIA CORPORATION or | ||
* its affiliates is strictly prohibited. | ||
*/ | ||
|
||
#pragma once | ||
#include "cuda_hint.cuh" | ||
#include "mha_stdheaders.cuh" | ||
#include "utils.cuh" | ||
#ifndef __CUDACC__ | ||
#include <cuda_runtime.h> | ||
#endif | ||
#include <cuda_fp16.h> | ||
#include <cuda_fp8.h> | ||
|
||
namespace gmma { | ||
|
||
enum class SwizzleMode : uint64_t { kNONE = 0, k128 = 1, k64 = 2, k32 = 3 }; | ||
|
||
struct MatDesc { | ||
uint64_t addr : 16; | ||
uint64_t dimKOffset : 16; | ||
uint64_t dimMNOffset : 16; | ||
uint64_t pad0 : 1; | ||
uint64_t baseOffset : 3; | ||
uint64_t pad1 : 10; | ||
SwizzleMode swizzle : 2; | ||
|
||
enum class Raw : uint64_t {}; | ||
|
||
[[nodiscard]] __device__ inline MatDesc withAddr(void const* data) const { | ||
MatDesc ret = *this; | ||
ret.addr = encode(__cvta_generic_to_shared(data)); | ||
return ret; | ||
} | ||
|
||
static __device__ inline uint32_t encode(uint32_t val) { return (val & 0x3FFFFU) >> 4; } | ||
|
||
__device__ inline bool operator==(MatDesc const& other) const { return raw() == other.raw(); } | ||
|
||
__device__ inline Raw const& raw() const { | ||
static_assert(sizeof(MatDesc) == 8); | ||
return reinterpret_cast<Raw const&>(*this); | ||
} | ||
|
||
static __device__ inline MatDesc fromRaw(Raw const& raw) { | ||
return reinterpret_cast<MatDesc const&>(raw); | ||
} | ||
}; | ||
|
||
static_assert(sizeof(MatDesc) == 8); | ||
|
||
[[nodiscard]] __device__ inline MatDesc::Raw addAddr(MatDesc::Raw base, void const* data) { | ||
assert((uint32_t(__cvta_generic_to_shared(data)) & ~0x3FFFFU) == 0); | ||
MatDesc::Raw ret = base; | ||
auto& u32x2 = reinterpret_cast<uint32_t(&)[2]>(ret); | ||
u32x2[0] += static_cast<uint32_t>(__cvta_generic_to_shared(data)) >> 4; | ||
return ret; | ||
} | ||
|
||
__device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset, | ||
uint32_t dimMNByteOffset, void const* patternStartAddr, | ||
SwizzleMode swizzleMode) { | ||
uint32_t const patternAddr = __cvta_generic_to_shared(patternStartAddr); | ||
uint32_t const baseAlign = [&]() -> uint32_t { | ||
switch (swizzleMode) { | ||
case SwizzleMode::kNONE: | ||
return 1; | ||
case SwizzleMode::k128: | ||
return 1024; | ||
case SwizzleMode::k64: | ||
return 512; | ||
case SwizzleMode::k32: | ||
return 256; | ||
} | ||
asm volatile("trap;\n"); | ||
return 0; | ||
}(); | ||
uint32_t const baseOffset = ((patternAddr % baseAlign == 0) ? 0U : ((patternAddr >> 0x7) & 0x7)); | ||
return MatDesc{ | ||
/*addr=*/MatDesc::encode(__cvta_generic_to_shared(data)), | ||
/*dimKOffset=*/MatDesc::encode(dimKByteOffset), | ||
/*dimMNOffset=*/MatDesc::encode(dimMNByteOffset), | ||
/*pad0=*/0, | ||
/*baseOffset=*/baseOffset, | ||
/*pad1=*/0, | ||
/*swizzle=*/swizzleMode, | ||
}; | ||
} | ||
|
||
__device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset, | ||
uint32_t dimMNByteOffset, SwizzleMode swizzleMode) { | ||
return makeMatDesc(data, dimKByteOffset, dimMNByteOffset, data, swizzleMode); | ||
} | ||
|
||
inline constexpr uint32_t instM = 64; | ||
|
||
template <typename MathElem> | ||
inline constexpr uint32_t instK = 32 / sizeof(MathElem); | ||
|
||
inline constexpr uint32_t instNBase = 8; | ||
|
||
// for both a and b, outer-dim is gemm-K and inner-dim is gemm-M or gemm-N | ||
// acc is used as both input and output. | ||
template <typename InputElem, uint32_t n, bool transA = false, bool transB = false> | ||
__device__ void mma_async_shmA(float (&acc)[exactDiv(n, instNBase)][2][2], MatDesc::Raw descA, | ||
MatDesc::Raw descB, bool accHasVal); | ||
template <typename InputElem, uint32_t n, bool transA = false, bool transB = false> | ||
__device__ void mma_async_regA(float (&acc)[exactDiv(n, instNBase)][2][2], | ||
uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal); | ||
|
||
__device__ inline void fence() { asm volatile("wgmma.fence.sync.aligned;\n"); } | ||
|
||
__device__ inline void commit_group() { asm volatile("wgmma.commit_group.sync.aligned;\n"); } | ||
|
||
template <uint32_t targetNbInFlightGroups> | ||
__device__ inline void wait_group() { | ||
asm volatile("wgmma.wait_group.sync.aligned %0\n; " ::"n"(targetNbInFlightGroups)); | ||
} | ||
|
||
template <bool swizzle, typename T, uint32_t rows, uint32_t cols, bool alignedForSwizzle> | ||
constexpr SwizzleMode getSwizzleMode(Array2D<T, rows, cols, alignedForSwizzle> const&) { | ||
constexpr auto rowBytes = Array2D<T, rows, cols, alignedForSwizzle>::rowBytes; | ||
if constexpr (!swizzle) { | ||
return SwizzleMode::kNONE; | ||
} | ||
if constexpr (rowBytes % 128 == 0) { | ||
return SwizzleMode::k128; | ||
} else if constexpr (rowBytes == 64) { | ||
return SwizzleMode::k64; | ||
} else { | ||
static_assert(rowBytes == 32); | ||
return SwizzleMode::k32; | ||
} | ||
} | ||
} // namespace gmma | ||
|
||
#include "gmma_impl.cuh" |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of making this a flag, could we pass a dtype?
Same for the other places where we pass: