Skip to content

Commit 3c56543

Browse files
authored
[CUDA] Quantized GEMV (ml-explore#3180)
1 parent 9eef9f1 commit 3c56543

File tree

10 files changed

+572
-120
lines changed

10 files changed

+572
-120
lines changed

mlx/backend/cuda/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ target_sources(
5656
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
5757
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
5858
${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
59-
${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmv.cu
6059
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
6160
${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm.cpp
6261
${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm_utils.cu

mlx/backend/cuda/quantized/qmm/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
target_sources(
22
mlx
33
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/qmm.cpp
4+
${CMAKE_CURRENT_SOURCE_DIR}/qmv.cu
5+
${CMAKE_CURRENT_SOURCE_DIR}/fp_qmv.cu
46
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n16_m1.cu
57
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n32_m1.cu
68
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n64_m2.cu
Lines changed: 61 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,21 @@
22

33
#include "mlx/backend/cuda/device/utils.cuh"
44
#include "mlx/backend/cuda/kernel_utils.cuh"
5-
#include "mlx/backend/cuda/quantized/qmv.h"
5+
#include "mlx/backend/cuda/quantized/qmm/qmm.h"
66
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
77
#include "mlx/backend/cuda/quantized/quantized_utils.h"
88
#include "mlx/dtype_utils.h"
99

1010
#include <cooperative_groups.h>
1111
#include <cooperative_groups/reduce.h>
1212

13-
namespace mlx::core::cu {
13+
namespace mlx::core {
1414

15-
namespace cg = cooperative_groups;
15+
constexpr int rows_per_block = 8;
16+
17+
namespace cu {
1618

17-
static constexpr int rows_per_block = 8;
19+
namespace cg = cooperative_groups;
1820

1921
template <typename T>
2022
__device__ void adjust_matrix_offsets(
@@ -199,6 +201,8 @@ __global__ void fp_qmv_batched(
199201
mat, scales, vec, out, rows, cols);
200202
}
201203

204+
} // namespace cu
205+
202206
template <typename F>
203207
void dispatch_1_2_4(int n, F&& f) {
204208
switch (n) {
@@ -221,11 +225,13 @@ void fp_qmv(
221225
array& out,
222226
int bits,
223227
int group_size,
224-
int M,
225-
int N,
226-
int K,
227-
CommandEncoder& encoder,
228+
cu::CommandEncoder& encoder,
228229
Stream s) {
230+
uint32_t M = x.shape(-2);
231+
uint32_t N = out.shape(-1);
232+
uint32_t K = x.shape(-1);
233+
uint32_t B = out.size() / (M * N);
234+
229235
// Make sure the last two dims of x and w, s, b are contiguous. This should
230236
// be relaxed for x.
231237
array vec = ensure_row_contiguous_matrix(x, encoder, s);
@@ -240,7 +246,6 @@ void fp_qmv(
240246
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
241247
if constexpr (!std::is_same_v<T, double>) {
242248
dim3 block_dims{WARP_SIZE, rows_per_block};
243-
uint32_t B = out.size() / (M * N);
244249
uint32_t blocks_y = (N + rows_per_block - 1) / rows_per_block;
245250
const uint32_t* mat_ptr = gpu_ptr<uint32_t>(mat);
246251
const T* vec_ptr = gpu_ptr<T>(vec);
@@ -256,55 +261,56 @@ void fp_qmv(
256261
n = 2;
257262
}
258263
dispatch_1_2_4(n, [&](auto n) {
259-
dispatch_bool(B > 1, [&](auto batched) {
260-
if (!batched.value) {
261-
auto kernel =
262-
fp_qmv_single<T, rows_per_block, n.value, 4, 32, true>;
263-
if (bits == 8) {
264-
kernel = fp_qmv_single<T, rows_per_block, n.value, 8, 32, true>;
265-
} else if (group_size == 16) {
266-
kernel = fp_qmv_single<T, rows_per_block, n.value, 4, 16, false>;
267-
}
268-
encoder.add_kernel_node(
269-
kernel,
270-
{static_cast<uint32_t>(M), blocks_y},
271-
block_dims,
272-
mat_ptr,
273-
gpu_ptr<uint8_t>(scales),
274-
vec_ptr,
275-
gpu_ptr<T>(out),
276-
N,
277-
K);
278-
} else {
279-
auto kernel =
280-
fp_qmv_batched<T, rows_per_block, n.value, 4, 32, true>;
281-
if (bits == 8) {
282-
kernel = fp_qmv_batched<T, rows_per_block, n.value, 8, 32, true>;
283-
} else if (group_size == 16) {
284-
kernel = fp_qmv_batched<T, rows_per_block, n.value, 4, 16, false>;
285-
}
286-
encoder.add_kernel_node(
287-
kernel,
288-
{static_cast<uint32_t>(M), blocks_y, B},
289-
block_dims,
290-
mat_ptr,
291-
gpu_ptr<uint8_t>(scales),
292-
vec_ptr,
293-
gpu_ptr<T>(out),
294-
N,
295-
K,
296-
vec.ndim() - 2,
297-
const_param(vec.shape()),
298-
const_param(vec.strides()),
299-
mat.ndim() - 2,
300-
const_param(mat.shape()),
301-
const_param(mat.strides()),
302-
const_param(scales.strides()));
264+
if (B == 1) {
265+
auto kernel =
266+
cu::fp_qmv_single<T, rows_per_block, n.value, 4, 32, true>;
267+
if (bits == 8) {
268+
kernel = cu::fp_qmv_single<T, rows_per_block, n.value, 8, 32, true>;
269+
} else if (group_size == 16) {
270+
kernel =
271+
cu::fp_qmv_single<T, rows_per_block, n.value, 4, 16, false>;
303272
}
304-
});
273+
encoder.add_kernel_node(
274+
kernel,
275+
{uint32_t(x.size() / K), blocks_y},
276+
block_dims,
277+
mat_ptr,
278+
gpu_ptr<uint8_t>(scales),
279+
vec_ptr,
280+
gpu_ptr<T>(out),
281+
N,
282+
K);
283+
} else {
284+
auto kernel =
285+
cu::fp_qmv_batched<T, rows_per_block, n.value, 4, 32, true>;
286+
if (bits == 8) {
287+
kernel =
288+
cu::fp_qmv_batched<T, rows_per_block, n.value, 8, 32, true>;
289+
} else if (group_size == 16) {
290+
kernel =
291+
cu::fp_qmv_batched<T, rows_per_block, n.value, 4, 16, false>;
292+
}
293+
encoder.add_kernel_node(
294+
kernel,
295+
{M, blocks_y, B},
296+
block_dims,
297+
mat_ptr,
298+
gpu_ptr<uint8_t>(scales),
299+
vec_ptr,
300+
gpu_ptr<T>(out),
301+
N,
302+
K,
303+
vec.ndim() - 2,
304+
const_param(vec.shape()),
305+
const_param(vec.strides()),
306+
mat.ndim() - 2,
307+
const_param(mat.shape()),
308+
const_param(mat.strides()),
309+
const_param(scales.strides()));
310+
}
305311
});
306312
}
307313
});
308314
}
309315

310-
} // namespace mlx::core::cu
316+
} // namespace mlx::core

mlx/backend/cuda/quantized/qmm/qmm.cpp

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,46 @@ void qmm_impl_sm90(
2121
Stream s);
2222
#endif // defined(MLX_CUDA_SM90A_ENABLED)
2323

24+
bool supports_qmm_sm90(
25+
const array& x,
26+
const array& w,
27+
const array& scales,
28+
const std::optional<array>& biases,
29+
const array& out,
30+
bool transpose,
31+
int bits,
32+
int group_size,
33+
QuantizationMode mode,
34+
cu::Device& device) {
35+
if (device.compute_capability_major() != 9) {
36+
return false;
37+
}
38+
int k = x.shape(-1);
39+
if (k % 64 != 0) {
40+
return false;
41+
}
42+
if (!biases) {
43+
return false;
44+
}
45+
if (!x.flags().row_contiguous || !w.flags().row_contiguous ||
46+
!scales.flags().row_contiguous || !biases->flags().row_contiguous) {
47+
return false;
48+
}
49+
if (!transpose) {
50+
return false;
51+
}
52+
if (bits % 2 != 0) {
53+
return false;
54+
}
55+
if (group_size < k) {
56+
return false;
57+
}
58+
if (mode != QuantizationMode::Affine) {
59+
return false;
60+
}
61+
return true;
62+
}
63+
2464
void qmm_sm90(
2565
const array& x,
2666
const array& w,
@@ -57,4 +97,71 @@ void qmm_sm90(
5797
#endif // defined(MLX_CUDA_SM90A_ENABLED)
5898
}
5999

100+
bool supports_fp_qmv(
101+
const array& x,
102+
const array& w,
103+
const array& scales,
104+
const std::optional<array>& biases,
105+
const array& out,
106+
bool transpose,
107+
int bits,
108+
int group_size,
109+
QuantizationMode mode,
110+
cu::Device& device) {
111+
bool non_batched = w.ndim() == 2;
112+
int k = x.shape(-1);
113+
int n = out.shape(-1);
114+
int vec_batch = non_batched ? x.size() / k : x.shape(-2);
115+
if (vec_batch > 8) {
116+
return false;
117+
}
118+
if (!transpose) {
119+
return false;
120+
}
121+
if (mode == QuantizationMode::Affine) {
122+
return false;
123+
}
124+
return true;
125+
}
126+
127+
bool supports_qmv(
128+
const array& x,
129+
const array& w,
130+
const array& scales,
131+
const std::optional<array>& biases,
132+
const array& out,
133+
bool transpose,
134+
int bits,
135+
int group_size,
136+
QuantizationMode mode,
137+
cu::Device& device) {
138+
int m = out.shape(-2);
139+
int n = out.shape(-1);
140+
int k = x.shape(-1);
141+
int l = out.size() / (m * n);
142+
if (l > 1) {
143+
return false;
144+
}
145+
if (n % 8 != 0 || k % 8 != 0) {
146+
return false;
147+
}
148+
if (!x.flags().row_contiguous || !w.flags().row_contiguous ||
149+
!scales.flags().row_contiguous) {
150+
return false;
151+
}
152+
if (biases && !biases->flags().row_contiguous) {
153+
return false;
154+
}
155+
if (!transpose) {
156+
return false;
157+
}
158+
if (bits % 2 != 0) {
159+
return false;
160+
}
161+
if (mode != QuantizationMode::Affine) {
162+
return false;
163+
}
164+
return true;
165+
}
166+
60167
} // namespace mlx::core

mlx/backend/cuda/quantized/qmm/qmm.h

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,24 @@
33
#pragma once
44

55
#include "mlx/backend/cuda/device.h"
6+
#include "mlx/primitives.h"
67

78
#include <optional>
89

910
namespace mlx::core {
1011

12+
bool supports_qmm_sm90(
13+
const array& x,
14+
const array& w,
15+
const array& scales,
16+
const std::optional<array>& biases,
17+
const array& out,
18+
bool transpose,
19+
int bits,
20+
int group_size,
21+
QuantizationMode mode,
22+
cu::Device& device);
23+
1124
void qmm_sm90(
1225
const array& x,
1326
const array& w,
@@ -19,4 +32,49 @@ void qmm_sm90(
1932
cu::CommandEncoder& encoder,
2033
Stream s);
2134

35+
bool supports_fp_qmv(
36+
const array& x,
37+
const array& w,
38+
const array& scales,
39+
const std::optional<array>& biases,
40+
const array& out,
41+
bool transpose,
42+
int bits,
43+
int group_size,
44+
QuantizationMode mode,
45+
cu::Device& device);
46+
47+
void fp_qmv(
48+
const array& x,
49+
const array& w,
50+
const array& scales,
51+
array& out,
52+
int bits,
53+
int group_size,
54+
cu::CommandEncoder& encoder,
55+
Stream s);
56+
57+
bool supports_qmv(
58+
const array& x,
59+
const array& w,
60+
const array& scales,
61+
const std::optional<array>& biases,
62+
const array& out,
63+
bool transpose,
64+
int bits,
65+
int group_size,
66+
QuantizationMode mode,
67+
cu::Device& device);
68+
69+
void qmv(
70+
const array& x,
71+
const array& w,
72+
const array& scales,
73+
const std::optional<array>& biases,
74+
array& out,
75+
int bits,
76+
int group_size,
77+
QuantizationMode mode,
78+
cu::CommandEncoder& encoder);
79+
2280
} // namespace mlx::core

0 commit comments

Comments
 (0)