Skip to content

Commit 5a347b2

Browse files
authored
[CUDA] Faster compilation and batch support in QMV (#3213)
1 parent db487f3 commit 5a347b2

File tree

3 files changed

+64
-51
lines changed

3 files changed

+64
-51
lines changed

mlx/backend/cuda/quantized/qmm/qmm.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -135,14 +135,8 @@ bool supports_qmv(
135135
int group_size,
136136
QuantizationMode mode,
137137
cu::Device& device) {
138-
int m = out.shape(-2);
139-
int n = out.shape(-1);
140138
int k = x.shape(-1);
141-
int l = out.size() / (m * n);
142-
if (l > 1) {
143-
return false;
144-
}
145-
if (n % 8 != 0 || k % 8 != 0) {
139+
if (k % 8 != 0) {
146140
return false;
147141
}
148142
if (!x.flags().row_contiguous || !w.flags().row_contiguous ||

mlx/backend/cuda/quantized/qmm/qmv.cu

Lines changed: 58 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ template <int N, typename T, typename Q>
2222
__device__ __forceinline__ void
2323
dequant_fma(const T* x, const Q* w, T scale, T bias, T* out) {
2424
// Read x/w into registers.
25-
auto x_vec = *(reinterpret_cast<const cutlass::AlignedArray<T, N>*>(x));
26-
auto w_vec = *(reinterpret_cast<const cutlass::AlignedArray<Q, N>*>(w));
25+
auto x_vec = *(reinterpret_cast<const cutlass::Array<T, N>*>(x));
26+
auto w_vec = *(reinterpret_cast<const cutlass::Array<Q, N>*>(w));
2727
// Output is assumed to be registers.
2828
auto* out_vec = reinterpret_cast<cutlass::Array<T, N>*>(out);
2929

@@ -52,8 +52,8 @@ template <
5252
__device__ __forceinline__ void
5353
dequant_fma(const T* x, const Q* w, T scale, T bias, float* out) {
5454
// Read x/w into registers.
55-
auto x_vec = *(reinterpret_cast<const cutlass::AlignedArray<T, N>*>(x));
56-
auto w_vec = *(reinterpret_cast<const cutlass::AlignedArray<Q, N>*>(w));
55+
auto x_vec = *(reinterpret_cast<const cutlass::Array<T, N>*>(x));
56+
auto w_vec = *(reinterpret_cast<const cutlass::Array<Q, N>*>(w));
5757
// Output is assumed to be registers.
5858
auto* out_vec = reinterpret_cast<cutlass::Array<float, N>*>(out);
5959

@@ -87,7 +87,9 @@ __global__ void qmv_kernel(
8787
const T* biases,
8888
T* out,
8989
int n,
90-
int k) {
90+
int k,
91+
bool broadcast_w) {
92+
auto grid = cg::this_grid();
9193
auto block = cg::this_thread_block();
9294
auto warp = cg::tiled_partition<WARP_SIZE>(block);
9395

@@ -98,8 +100,10 @@ __global__ void qmv_kernel(
98100
}
99101

100102
// Advance pointers of x/out.
101-
x += block.group_index().y * k;
102-
out += block.group_index().y * n;
103+
int m = grid.dim_blocks().y;
104+
int l = block.group_index().z;
105+
x += block.group_index().y * k + m * k * l;
106+
out += block.group_index().y * n + m * n * l;
103107

104108
// For sub-byte Q, pointer moves by 8bits for each advance, e.g. w += 1 would
105109
// move past 2 elements for 4-bit Q.
@@ -110,10 +114,11 @@ __global__ void qmv_kernel(
110114
int groups_per_row = k / group_size;
111115

112116
// Advance w/scales/biases to current row.
113-
w += static_cast<int64_t>(row) * k / w_step;
114-
scales += static_cast<int64_t>(row) * groups_per_row;
117+
int w_batch = broadcast_w ? 0 : l;
118+
w += (static_cast<int64_t>(row) + n * w_batch) * k / w_step;
119+
scales += (static_cast<int64_t>(row) + n * w_batch) * groups_per_row;
115120
if constexpr (has_bias) {
116-
biases += static_cast<int64_t>(row) * groups_per_row;
121+
biases += (static_cast<int64_t>(row) + n * w_batch) * groups_per_row;
117122
}
118123

119124
// Accumulations of current row.
@@ -168,14 +173,17 @@ void qmv(
168173
int m,
169174
int n,
170175
int k,
176+
int l,
177+
bool broadcast_w,
171178
F&& launch_kernel) {
172179
constexpr int rows_per_block = 8;
173180
constexpr int elems_per_thread =
174181
(cute::sizeof_bits_v<T> <= 16 && cute::sizeof_bits_v<Q> <= 4) ? 16 : 8;
175182

176-
dim3 num_blocks{uint32_t(cuda::ceil_div(n, rows_per_block)), uint32_t(m)};
183+
dim3 num_blocks{
184+
uint32_t(cuda::ceil_div(n, rows_per_block)), uint32_t(m), uint32_t(l)};
177185
dim3 block_dims{WARP_SIZE, rows_per_block};
178-
void* args[] = {&x, &w, &scales, &biases, &out, &n, &k};
186+
void* args[] = {&x, &w, &scales, &biases, &out, &n, &k, &broadcast_w};
179187

180188
dispatch_bool(k % (WARP_SIZE * elems_per_thread), [&](auto has_residue_k) {
181189
auto* kernel = &qmv_kernel<
@@ -207,34 +215,9 @@ inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) {
207215
}
208216
}
209217

210-
template <typename F>
211-
inline void
212-
dispatch_quant_types(int bits, QuantizationMode mode, const char* tag, F&& f) {
213-
if (mode == QuantizationMode::Mxfp4) {
214-
f.template operator()<cutlass::float_e2m1_t>();
215-
} else if (mode == QuantizationMode::Mxfp8) {
216-
f.template operator()<cutlass::float_e4m3_t>();
217-
} else if (mode == QuantizationMode::Nvfp4) {
218-
f.template operator()<cutlass::float_e2m1_t>();
219-
} else {
220-
if (bits == 2) {
221-
f.template operator()<cutlass::uint2b_t>();
222-
} else if (bits == 4) {
223-
f.template operator()<cutlass::uint4b_t>();
224-
} else if (bits == 8) {
225-
f.template operator()<uint8_t>();
226-
} else {
227-
throw std::invalid_argument(
228-
fmt::format("{} {}-bit quantization is not supported.", tag, bits));
229-
}
230-
}
231-
}
232-
233218
template <typename F>
234219
inline void dispatch_groups(int group_size, const char* tag, F&& f) {
235-
if (group_size == 16) {
236-
f.template operator()<16>();
237-
} else if (group_size == 32) {
220+
if (group_size == 32) {
238221
f.template operator()<32>();
239222
} else if (group_size == 64) {
240223
f.template operator()<64>();
@@ -246,6 +229,35 @@ inline void dispatch_groups(int group_size, const char* tag, F&& f) {
246229
}
247230
}
248231

232+
template <typename F>
233+
inline void dispatch_quant_types(
234+
int bits,
235+
int group_size,
236+
QuantizationMode mode,
237+
const char* tag,
238+
F&& f) {
239+
if (mode == QuantizationMode::Mxfp4) {
240+
f.template operator()<cutlass::float_e2m1_t, 16>();
241+
} else if (mode == QuantizationMode::Mxfp8) {
242+
f.template operator()<cutlass::float_e4m3_t, 32>();
243+
} else if (mode == QuantizationMode::Nvfp4) {
244+
f.template operator()<cutlass::float_e2m1_t, 32>();
245+
} else {
246+
dispatch_groups(group_size, tag, [&]<int group_size>() {
247+
if (bits == 2) {
248+
f.template operator()<cutlass::uint2b_t, group_size>();
249+
} else if (bits == 4) {
250+
f.template operator()<cutlass::uint4b_t, group_size>();
251+
} else if (bits == 8) {
252+
f.template operator()<uint8_t, group_size>();
253+
} else {
254+
throw std::invalid_argument(
255+
fmt::format("{} {}-bit quantization is not supported.", tag, bits));
256+
}
257+
});
258+
}
259+
}
260+
249261
void qmv(
250262
const array& x,
251263
const array& w,
@@ -260,19 +272,21 @@ void qmv(
260272
int m = out.shape(-2);
261273
int n = out.shape(-1);
262274
int k = x.shape(-1);
275+
int l = out.size() / (m * n);
276+
bool broadcast_w = w.ndim() == 2;
263277

264278
dispatch_element_types(out.dtype(), tag, [&]<typename T>() {
265-
dispatch_bool(biases.has_value(), [&](auto has_bias) {
266-
dispatch_quant_types(bits, mode, tag, [&]<typename Q>() {
267-
dispatch_groups(group_size, tag, [&]<int group_size>() {
279+
dispatch_quant_types(
280+
bits, group_size, mode, tag, [&]<typename Q, int group_size>() {
268281
encoder.set_input_array(x);
269282
encoder.set_input_array(w);
270283
encoder.set_input_array(scales);
271284
if (biases) {
272285
encoder.set_input_array(*biases);
273286
}
274287
encoder.set_output_array(out);
275-
cu::qmv<group_size, has_bias.value>(
288+
constexpr bool has_bias = !cutlass::has_negative_zero_v<Q>;
289+
cu::qmv<group_size, has_bias>(
276290
gpu_ptr<T>(x),
277291
gpu_ptr<Q>(w),
278292
gpu_ptr<T>(scales),
@@ -281,13 +295,13 @@ void qmv(
281295
m,
282296
n,
283297
k,
298+
l,
299+
broadcast_w,
284300
[&](auto* kernel, dim3 num_blocks, dim3 block_dims, void** args) {
285301
encoder.add_kernel_node_raw(
286302
kernel, num_blocks, block_dims, {}, 0, args);
287303
});
288304
});
289-
});
290-
});
291305
});
292306
}
293307

mlx/backend/cuda/quantized/quantized.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,12 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
8888
throw std::runtime_error(
8989
fmt::format(
9090
"[quantized_matmul] No implementation for "
91+
"problem shape: {}x{}x{}x{} "
9192
"activation: {}, bits: {}, group size: {}, mode: \"{}\".",
93+
M,
94+
N,
95+
K,
96+
B,
9297
dtype_to_string(x.dtype()),
9398
bits_,
9499
group_size_,

0 commit comments

Comments
 (0)