Skip to content

Commit 3b3590b

Browse files
authored
[CUDA] Use fp16 accumulation for 4-bit quant in GEMV (ml-explore#3197)
1 parent 3c56543 commit 3b3590b

File tree

1 file changed

+35
-21
lines changed
  • mlx/backend/cuda/quantized/qmm/qmv.cu

1 file changed

+35
-21
lines changed

mlx/backend/cuda/quantized/qmm/qmv.cu

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,36 @@ namespace cg = cooperative_groups;
2020
// out = fma(x, w_dq, out)
2121
// out = fma(x, w_dq, out)
//
// Dequantizes N packed quantized weights, then fused-multiply-adds them with
// the activations into the accumulator. Accumulation happens in T itself
// (see the float-accumulating overload for narrow types promoted to float).
// Preconditions: x and w must satisfy the alignment required by
// cutlass::AlignedArray<.., N>; out must point to register-resident storage.
template <int N, typename T, typename Q>
__device__ __forceinline__ void
dequant_fma(const T* x, const Q* w, T scale, T bias, T* out) {
  // Pull the activations and the packed weights into registers.
  auto xs = *(reinterpret_cast<const cutlass::AlignedArray<T, N>*>(x));
  auto ws = *(reinterpret_cast<const cutlass::AlignedArray<Q, N>*>(w));
  // The accumulator is assumed to already live in registers.
  auto* acc = reinterpret_cast<cutlass::Array<T, N>*>(out);

  // Convert the quantized values to T, then apply the affine scale/bias.
  cutlass::NumericArrayConverter<T, Q, N> to_t;
  cutlass::Array<T, N> dq = to_t(ws);
  if constexpr (cuda::std::is_same_v<T, float>) {
    // float path: apply scale/bias element by element.
#pragma unroll
    for (int i = 0; i < N; ++i) {
      dq[i] = dq[i] * scale + bias;
    }
  } else {
    // Narrow types (e.g. half/bfloat16): use CUTLASS array-scalar ops.
    dq = dq * scale + bias;
  }

  // acc += xs * dq, element-wise.
  *acc = cutlass::fma(xs, dq, *acc);
}
45+
46+
// Specialization for doing float32 accumulations on narrow types.
47+
template <
48+
int N,
49+
typename T,
50+
typename Q,
51+
typename = cuda::std::enable_if_t<!cuda::std::is_same_v<T, float>>>
52+
__device__ __forceinline__ void
2353
dequant_fma(const T* x, const Q* w, T scale, T bias, float* out) {
2454
// Read x/w into registers.
2555
auto x_vec = *(reinterpret_cast<const cutlass::AlignedArray<T, N>*>(x));
@@ -42,24 +72,6 @@ dequant_fma(const T* x, const Q* w, T scale, T bias, float* out) {
4272
*out_vec = cutlass::fma(x_f, w_f, *out_vec);
4373
}
4474

45-
// Specialized for float which does not need promotions.
46-
template <int N, typename Q>
47-
__device__ __forceinline__ void
48-
dequant_fma(const float* x, const Q* w, float scale, float bias, float* out) {
49-
auto x_vec = *(reinterpret_cast<const cutlass::AlignedArray<float, N>*>(x));
50-
auto w_vec = *(reinterpret_cast<const cutlass::AlignedArray<Q, N>*>(w));
51-
auto* out_vec = reinterpret_cast<cutlass::Array<float, N>*>(out);
52-
53-
cutlass::NumericArrayConverter<float, Q, N> converter;
54-
cutlass::Array<float, N> w_dq = converter(w_vec);
55-
#pragma unroll
56-
for (int i = 0; i < N; ++i) {
57-
w_dq[i] = w_dq[i] * scale + bias;
58-
}
59-
60-
*out_vec = cutlass::fma(x_vec, w_dq, *out_vec);
61-
}
62-
6375
template <
6476
int rows_per_block,
6577
int elems_per_thread,
@@ -91,7 +103,8 @@ __global__ void qmv_kernel(
91103

92104
// For sub-byte Q, pointer moves by 8bits for each advance, e.g. w += 1 would
93105
// move past 2 elements for 4-bit Q.
94-
constexpr int w_step = 8 / cuda::std::min(8, cute::sizeof_bits_v<Q>);
106+
constexpr int bits = cute::sizeof_bits_v<Q>;
107+
constexpr int w_step = 8 / cuda::std::min(8, bits);
95108

96109
// How many groups (and scales/biases) in a row.
97110
int groups_per_row = k / group_size;
@@ -104,7 +117,7 @@ __global__ void qmv_kernel(
104117
}
105118

106119
// Accumulations of current row.
107-
float sums[elems_per_thread] = {};
120+
cuda::std::conditional_t<(bits >= 8), float, T> sums[elems_per_thread] = {};
108121

109122
auto dequant_fma_tile = [&](int idx) {
110123
T scale = scales[idx / group_size];
@@ -157,7 +170,8 @@ void qmv(
157170
int k,
158171
F&& launch_kernel) {
159172
constexpr int rows_per_block = 8;
160-
constexpr int elems_per_thread = 8;
173+
constexpr int elems_per_thread =
174+
(cute::sizeof_bits_v<T> <= 16 && cute::sizeof_bits_v<Q> <= 4) ? 16 : 8;
161175

162176
dim3 num_blocks{uint32_t(cuda::ceil_div(n, rows_per_block)), uint32_t(m)};
163177
dim3 block_dims{WARP_SIZE, rows_per_block};

0 commit comments

Comments
 (0)