
Commit 3a7c8b3

[Cuda] support 8 bits in MatMulNBits (microsoft#24509)
### Description

Support 8 bits in the MatMulNBits CUDA kernel.

The `MatMulFloat8bKernel` CUDA kernel performs a matrix multiplication (GEMM) in which matrix $B$ is block-wise quantized to 8-bit integers. The kernel computes $Output = A \times B$, where:

* $A$ is the input activation (shape `[M, K]`, a row vector when `M` is 1) of type `T` (`float` or `half`).
* $B$ is a matrix (shape `[K, N]`) quantized to 8-bit unsigned integers (`uint8_t`) with a block structure, stored as `[N, K/block_size, block_size]`.
* `scales_data` contains the dequantization scales (shape `[N, K/block_size]`).
* `zero_points` contains the dequantization zero points (shape `[N, K/block_size]`), used when `has_zero_point` is true.
* `output` is the result (shape `[M, N]`).

The kernel uses a thread block of shape `(kWarpSize, kColsPerThreadBlock)`, so each block handles `kColsPerThreadBlock` (8) output columns. Each warp in the block is responsible for one output element (`[m_id, n_id]`). Threads within a warp cooperate to compute the dot product along the K dimension, and each thread (`lane_id`) handles `kElementsPerThreadPerIteration` (8) elements of K per step.

Here's a breakdown of the three algorithms (`kKernelAlgo`); hedged sketches of the two loop structures follow this list.

1. **`kKernelAlgo = 0` (Unrolling):**
   * **Strategy:** Processes the K dimension in large steps (`k_per_iter = kWarpSize * kElementsPerThreadPerIteration = 32 * 8 = 256`). Inside the main loop, a macro (`UnRollReduction`) with `#pragma unroll` directives aggressively unrolls the innermost computations, trying unrolling factors of 16, 4, and 1 in sequence to cover as much of the K dimension as possible with unrolled code.
   * **Pros:** Can significantly reduce loop overhead (branch instructions, counter updates) and expose more instruction-level parallelism, potentially hiding memory latency.
   * **Cons:** Can greatly increase compiled code size (register pressure, potential instruction-cache misses). Effectiveness depends heavily on the compiler and the specific GPU architecture, and the multi-stage unrolling adds complexity. It requires `k_per_iter` to be a multiple of `block_size` for correct scale/zero-point indexing inside the unrolled loop.
   * **Performance expectation:** Potentially the fastest *if* the unrolling is effective on the target hardware and doesn't cause resource issues (registers, cache). Often good for compute-bound or latency-bound scenarios where loop overhead is the bottleneck.
2. **`kKernelAlgo = 1` (Simple Loop):**
   * **Strategy:** Also walks the K dimension in steps of `k_per_iter` (256), but with a plain `for` loop and no explicit `#pragma unroll`, relying on the compiler's default loop optimizations (the first sketch below shows this structure).
   * **Pros:** Simpler code and smaller code size than Algorithm 0; less likely to cause register pressure or instruction-cache issues; easier for the compiler to reason about.
   * **Cons:** May incur higher loop overhead than effective unrolling; performance can suffer if that overhead is significant.
   * **Performance expectation:** A solid baseline. It may match Algorithm 0 if the compiler unrolls implicitly, or beat it if Algorithm 0 pays code-bloat penalties.
3. **`kKernelAlgo = 2` (Block Size Iteration):**
   * **Strategy:** Changes the iteration strategy fundamentally. Instead of fixed steps of `k_per_iter`, it iterates over quantization blocks: the outer loop runs `blocks_per_K` (`K / block_size`) times, the scale and zero point for the *entire block* are fetched once per warp, and each thread then checks whether its assigned K elements (`lane_offset`) fall inside the current `block_size` chunk and processes them with the fetched scale/zero point (see the second sketch below).
   * **Pros:** Aligns directly with the block-quantization layout. Scale/zero-point values are fetched less often (once per `block_size` chunk per warp), potentially reducing shared-memory bank conflicts or register usage compared with computing the index (`current_meta_k`) at every inner step as in Algorithms 0/1. It may also have better memory access patterns for the scale/zero-point data.
   * **Cons:** The outer loop runs `K / block_size` times; with a small `block_size` (e.g. 16 or 32) that is many iterations. The in-loop check (`if (current_k_base < k_end_block ...)`) adds conditional execution.
   * **Performance expectation:** Depends heavily on `block_size`. With a large `block_size` (e.g. 128 or 256) the outer loop is short, and the saving from fetching the scale/zero point once per block may outweigh the overhead; with a small `block_size` the outer-loop overhead may dominate.
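The following is a minimal sketch (CUDA, not the actual kernel source) of the fixed-step main loop shared by Algorithms 0 and 1, based on the description above. The helper name `WarpDotQuant8`, the default zero point of 128, and the omitted K-tail handling are assumptions for illustration; Algorithm 0 would emit the same body through the `UnRollReduction` macro with unroll factors 16/4/1.

```cuda
// Hedged sketch: per-warp dot product along K for one output element [m_id, n_id],
// iterating in fixed steps of k_per_iter = kWarpSize * kElementsPerThreadPerIteration.
// Algorithm 1 uses this plain loop; Algorithm 0 unrolls the same body via UnRollReduction.
template <class T>
__device__ float WarpDotQuant8(const T* a_row,              // A[m_id, :], length K
                               const uint8_t* b_col,        // B[n_id, :], 8-bit quantized, length K
                               const T* scales,             // scales[n_id, :], length K / block_size
                               const uint8_t* zero_points,  // zero_points[n_id, :] or nullptr
                               int K, int block_size, int lane_id) {
  constexpr int kWarpSize = 32;
  constexpr int kElementsPerThreadPerIteration = 8;
  constexpr int k_per_iter = kWarpSize * kElementsPerThreadPerIteration;  // 256

  float sum = 0.0f;
  // Each lane owns 8 consecutive K elements of every 256-element chunk. Assuming
  // block_size is a multiple of 8, those 8 elements share one scale/zero point.
  for (int k_base = 0; k_base + k_per_iter <= K; k_base += k_per_iter) {
    const int k = k_base + lane_id * kElementsPerThreadPerIteration;
    const int meta_k = k / block_size;  // recomputed every step (Algorithms 0/1)
    const float scale = static_cast<float>(scales[meta_k]);
    const float zp = zero_points ? static_cast<float>(zero_points[meta_k]) : 128.0f;  // assumed default
#pragma unroll
    for (int i = 0; i < kElementsPerThreadPerIteration; ++i) {
      const float b = (static_cast<float>(b_col[k + i]) - zp) * scale;  // dequantize
      sum += static_cast<float>(a_row[k + i]) * b;
    }
  }
  // K-tail (K % 256 != 0) handling omitted for brevity.

  // Warp-level reduction of the 32 partial sums; the result is valid in lane 0.
  for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  return sum;
}
```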
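And a corresponding sketch of Algorithm 2's block-wise traversal, under the same assumptions (hypothetical helper name, `block_size` a multiple of 8 and at most 256, default zero point of 128). The point to notice is that the scale/zero point are loaded once per quantization block rather than once per 256-element step.

```cuda
// Hedged sketch: Algorithm 2 iterates over quantization blocks. The scale/zero point
// are fetched once per block; lanes whose 8-element slice lies outside the block skip it.
template <class T>
__device__ float WarpDotQuant8BlockIter(const T* a_row, const uint8_t* b_col,
                                        const T* scales, const uint8_t* zero_points,
                                        int K, int block_size, int lane_id) {
  constexpr int kWarpSize = 32;
  constexpr int kElementsPerThreadPerIteration = 8;
  const int blocks_per_K = K / block_size;
  const int lane_offset = lane_id * kElementsPerThreadPerIteration;

  float sum = 0.0f;
  for (int block = 0; block < blocks_per_K; ++block) {
    // One scale/zero-point load per block (per warp), instead of per inner step.
    const float scale = static_cast<float>(scales[block]);
    const float zp = zero_points ? static_cast<float>(zero_points[block]) : 128.0f;  // assumed default

    const int k_begin = block * block_size;
    const int k_end_block = k_begin + block_size;
    const int current_k_base = k_begin + lane_offset;
    // Conditional execution: with block_size <= 256 only block_size / 8 lanes have work here.
    if (current_k_base < k_end_block) {
#pragma unroll
      for (int i = 0; i < kElementsPerThreadPerIteration; ++i) {
        const float b = (static_cast<float>(b_col[current_k_base + i]) - zp) * scale;
        sum += static_cast<float>(a_row[current_k_base + i]) * b;
      }
    }
  }
  // Warp-level reduction; the result is valid in lane 0.
  for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  return sum;
}
```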
**Next steps:**

1. **Profile:** The most reliable approach is to benchmark all three algorithms (`kKernelAlgo = 0, 1, 2`) on the target GPU with representative input sizes (`N`, `K`), data types (`T`), and `block_size` values. Use profiling tools such as NVIDIA Nsight Compute to analyze execution time, occupancy, instruction throughput, memory bandwidth, cache hit rates, and register spills.
2. **Hypothesize based on `block_size`:**
   * For **large `block_size`** (e.g. 128, 256), Algorithm 2 might be competitive or even the best thanks to its efficient scale/zero-point handling; Algorithm 0 could also be very fast.
   * For **small `block_size`** (e.g. 16, 32), Algorithm 0 (unroll) or Algorithm 1 (simple loop) might outperform Algorithm 2 because of lower loop overhead along K.
3. **Compare** performance with the TRT-LLM FpA IntB GEMM.

### Motivation and Context

4-bit quantization loses accuracy for some LLMs, so some layers need more bits.
1 parent 1e118d6 commit 3a7c8b3

File tree: 9 files changed (+1489, -201 lines)

onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh

Lines changed: 60 additions & 0 deletions
```diff
@@ -7,6 +7,43 @@
 namespace onnxruntime {
 namespace contrib {
 namespace cuda {
+
+///////////////////////////////////////////////////////////////////////////////
+// A more general block-wise dequantization implementation that supports
+// different block sizes and block orientations (row-wise/column-wise).
+template <
+    int Row_,    ///< rows of a matrix
+    int Column_  ///< columns of a matrix
+    >
+struct Shape2D {
+  static int const kRow = Row_;              ///< rows of a matrix
+  static int const kColumn = Column_;        ///< columns of a matrix
+  static int const kCount = Row_ * Column_;  ///< total number of elements in a matrix
+};
+
+/**
+ * @brief Blockwise quantization constants
+ * @tparam ElementT   source data type, e.g. fp32/fp16
+ * @tparam block_size number of elements quantized together
+ * @tparam qbits      number of bits in each quantized element
+ * @tparam Columnwise true: elements in a block come from one single column
+ *                    false: elements in a block come from one single row
+ */
+template <
+    typename ElementT,
+    int32_t block_size,
+    int32_t qbits,
+    bool Columnwise>
+struct BlkQuantTraits {
+  // number of qbit elements to pack into whole bytes
+  static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 4 : 0;
+  static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!");
+
+  using QuantBlk = std::conditional_t<Columnwise, Shape2D<block_size, 1>, Shape2D<1, block_size>>;
+
+  using ThreadBlk = Shape2D<QuantBlk::kRow * kPackSize, QuantBlk::kColumn>;
+};
+
 template <class T, typename ZeroT>
 Status Dequantize4Bits(
     T* output,
@@ -19,6 +56,18 @@ Status Dequantize4Bits(
     int block_size,
     cudaStream_t stream);
 
+template <class T, typename ZeroT>
+Status Dequantize8Bits(
+    T* output,
+    const uint8_t* quant_data,
+    const T* scales_data,
+    const ZeroT* zero_points,
+    const int32_t* reorder_idx,
+    int k,
+    int n,
+    int block_size,
+    cudaStream_t stream);
+
 /**
  * @brief Dequantize a block-wise quantized matrix, and store the result in a
  * column major matrix for use in subsequent GEMM. This implementation supports
@@ -45,6 +94,17 @@ Status DequantizeBlockwise4b(
     int columns,
     cudaStream_t stream);
 
+template <typename T>
+Status DequantizeBlockwise8b(
+    T* dst,
+    const uint8_t* qelements,
+    const T* scales,
+    const uint8_t* zero_points,
+    int block_size,
+    bool columnwise,
+    int rows,
+    int columns,
+    cudaStream_t stream);
 } // namespace cuda
 } // namespace contrib
 } // namespace onnxruntime
```
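For reference, a hedged compile-time illustration of what the new `BlkQuantTraits` resolves to for `qbits = 8`, plus the parameter order of the `Dequantize8Bits` declaration above. The include path, the `half`/`uint8_t` instantiation, and the commented-out call are assumptions for illustration, not taken from the commit.

```cuda
// Sketch only: the static_asserts follow directly from the trait definitions in the header above.
#include <cuda_fp16.h>
#include "contrib_ops/cuda/quantization/dequantize_blockwise.cuh"  // path assumed relative to onnxruntime/

using Traits8b = onnxruntime::contrib::cuda::BlkQuantTraits<half, /*block_size=*/32,
                                                            /*qbits=*/8, /*Columnwise=*/true>;
static_assert(Traits8b::kPackSize == 1, "8-bit values are not packed: one element per byte");
static_assert(Traits8b::QuantBlk::kRow == 32 && Traits8b::QuantBlk::kColumn == 1,
              "column-wise: a 32-element block spans 32 rows of a single column");
static_assert(Traits8b::ThreadBlk::kCount == 32, "thread block shape equals the quant block for 8-bit");

// Calling the new entry point would follow the declaration above; placeholders only:
//   Dequantize8Bits<half, uint8_t>(output, quant_data, scales_data, zero_points,
//                                  /*reorder_idx=*/nullptr, k, n, block_size, stream);
```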
