@@ -22,8 +22,8 @@ template <int N, typename T, typename Q>
2222__device__ __forceinline__ void
2323dequant_fma (const T* x, const Q* w, T scale, T bias, T* out) {
2424 // Read x/w into registers.
25- auto x_vec = *(reinterpret_cast <const cutlass::AlignedArray <T, N>*>(x));
26- auto w_vec = *(reinterpret_cast <const cutlass::AlignedArray <Q, N>*>(w));
25+ auto x_vec = *(reinterpret_cast <const cutlass::Array <T, N>*>(x));
26+ auto w_vec = *(reinterpret_cast <const cutlass::Array <Q, N>*>(w));
2727 // Output is assumed to be registers.
2828 auto * out_vec = reinterpret_cast <cutlass::Array<T, N>*>(out);
2929
@@ -52,8 +52,8 @@ template <
5252__device__ __forceinline__ void
5353dequant_fma (const T* x, const Q* w, T scale, T bias, float * out) {
5454 // Read x/w into registers.
55- auto x_vec = *(reinterpret_cast <const cutlass::AlignedArray <T, N>*>(x));
56- auto w_vec = *(reinterpret_cast <const cutlass::AlignedArray <Q, N>*>(w));
55+ auto x_vec = *(reinterpret_cast <const cutlass::Array <T, N>*>(x));
56+ auto w_vec = *(reinterpret_cast <const cutlass::Array <Q, N>*>(w));
5757 // Output is assumed to be registers.
5858 auto * out_vec = reinterpret_cast <cutlass::Array<float , N>*>(out);
5959
@@ -87,7 +87,9 @@ __global__ void qmv_kernel(
8787 const T* biases,
8888 T* out,
8989 int n,
90- int k) {
90+ int k,
91+ bool broadcast_w) {
92+ auto grid = cg::this_grid ();
9193 auto block = cg::this_thread_block ();
9294 auto warp = cg::tiled_partition<WARP_SIZE>(block);
9395
@@ -98,8 +100,10 @@ __global__ void qmv_kernel(
98100 }
99101
100102 // Advance pointers of x/out.
101- x += block.group_index ().y * k;
102- out += block.group_index ().y * n;
103+ int m = grid.dim_blocks ().y ;
104+ int l = block.group_index ().z ;
105+ x += block.group_index ().y * k + m * k * l;
106+ out += block.group_index ().y * n + m * n * l;
103107
104108 // For sub-byte Q, pointer moves by 8bits for each advance, e.g. w += 1 would
105109 // move past 2 elements for 4-bit Q.
@@ -110,10 +114,11 @@ __global__ void qmv_kernel(
110114 int groups_per_row = k / group_size;
111115
112116 // Advance w/scales/biases to current row.
113- w += static_cast <int64_t >(row) * k / w_step;
114- scales += static_cast <int64_t >(row) * groups_per_row;
117+ int w_batch = broadcast_w ? 0 : l;
118+ w += (static_cast <int64_t >(row) + n * w_batch) * k / w_step;
119+ scales += (static_cast <int64_t >(row) + n * w_batch) * groups_per_row;
115120 if constexpr (has_bias) {
116- biases += static_cast <int64_t >(row) * groups_per_row;
121+ biases += ( static_cast <int64_t >(row) + n * w_batch ) * groups_per_row;
117122 }
118123
119124 // Accumulations of current row.
@@ -168,14 +173,17 @@ void qmv(
168173 int m,
169174 int n,
170175 int k,
176+ int l,
177+ bool broadcast_w,
171178 F&& launch_kernel) {
172179 constexpr int rows_per_block = 8 ;
173180 constexpr int elems_per_thread =
174181 (cute::sizeof_bits_v<T> <= 16 && cute::sizeof_bits_v<Q> <= 4 ) ? 16 : 8 ;
175182
176- dim3 num_blocks{uint32_t (cuda::ceil_div (n, rows_per_block)), uint32_t (m)};
183+ dim3 num_blocks{
184+ uint32_t (cuda::ceil_div (n, rows_per_block)), uint32_t (m), uint32_t (l)};
177185 dim3 block_dims{WARP_SIZE, rows_per_block};
178- void * args[] = {&x, &w, &scales, &biases, &out, &n, &k};
186+ void * args[] = {&x, &w, &scales, &biases, &out, &n, &k, &broadcast_w };
179187
180188 dispatch_bool (k % (WARP_SIZE * elems_per_thread), [&](auto has_residue_k) {
181189 auto * kernel = &qmv_kernel<
@@ -207,34 +215,9 @@ inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) {
207215 }
208216}
209217
210- template <typename F>
211- inline void
212- dispatch_quant_types (int bits, QuantizationMode mode, const char * tag, F&& f) {
213- if (mode == QuantizationMode::Mxfp4) {
214- f.template operator ()<cutlass::float_e2m1_t >();
215- } else if (mode == QuantizationMode::Mxfp8) {
216- f.template operator ()<cutlass::float_e4m3_t >();
217- } else if (mode == QuantizationMode::Nvfp4) {
218- f.template operator ()<cutlass::float_e2m1_t >();
219- } else {
220- if (bits == 2 ) {
221- f.template operator ()<cutlass::uint2b_t >();
222- } else if (bits == 4 ) {
223- f.template operator ()<cutlass::uint4b_t >();
224- } else if (bits == 8 ) {
225- f.template operator ()<uint8_t >();
226- } else {
227- throw std::invalid_argument (
228- fmt::format (" {} {}-bit quantization is not supported." , tag, bits));
229- }
230- }
231- }
232-
233218template <typename F>
234219inline void dispatch_groups (int group_size, const char * tag, F&& f) {
235- if (group_size == 16 ) {
236- f.template operator ()<16 >();
237- } else if (group_size == 32 ) {
220+ if (group_size == 32 ) {
238221 f.template operator ()<32 >();
239222 } else if (group_size == 64 ) {
240223 f.template operator ()<64 >();
@@ -246,6 +229,35 @@ inline void dispatch_groups(int group_size, const char* tag, F&& f) {
246229 }
247230}
248231
232+ template <typename F>
233+ inline void dispatch_quant_types (
234+ int bits,
235+ int group_size,
236+ QuantizationMode mode,
237+ const char * tag,
238+ F&& f) {
239+ if (mode == QuantizationMode::Mxfp4) {
240+ f.template operator ()<cutlass::float_e2m1_t , 16 >();
241+ } else if (mode == QuantizationMode::Mxfp8) {
242+ f.template operator ()<cutlass::float_e4m3_t , 32 >();
243+ } else if (mode == QuantizationMode::Nvfp4) {
244+ f.template operator ()<cutlass::float_e2m1_t , 32 >();
245+ } else {
246+ dispatch_groups (group_size, tag, [&]<int group_size>() {
247+ if (bits == 2 ) {
248+ f.template operator ()<cutlass::uint2b_t , group_size>();
249+ } else if (bits == 4 ) {
250+ f.template operator ()<cutlass::uint4b_t , group_size>();
251+ } else if (bits == 8 ) {
252+ f.template operator ()<uint8_t , group_size>();
253+ } else {
254+ throw std::invalid_argument (
255+ fmt::format (" {} {}-bit quantization is not supported." , tag, bits));
256+ }
257+ });
258+ }
259+ }
260+
249261void qmv (
250262 const array& x,
251263 const array& w,
@@ -260,19 +272,21 @@ void qmv(
260272 int m = out.shape (-2 );
261273 int n = out.shape (-1 );
262274 int k = x.shape (-1 );
275+ int l = out.size () / (m * n);
276+ bool broadcast_w = w.ndim () == 2 ;
263277
264278 dispatch_element_types (out.dtype (), tag, [&]<typename T>() {
265- dispatch_bool (biases.has_value (), [&](auto has_bias) {
266- dispatch_quant_types (bits, mode, tag, [&]<typename Q>() {
267- dispatch_groups (group_size, tag, [&]<int group_size>() {
279+ dispatch_quant_types (
280+ bits, group_size, mode, tag, [&]<typename Q, int group_size>() {
268281 encoder.set_input_array (x);
269282 encoder.set_input_array (w);
270283 encoder.set_input_array (scales);
271284 if (biases) {
272285 encoder.set_input_array (*biases);
273286 }
274287 encoder.set_output_array (out);
275- cu::qmv<group_size, has_bias.value >(
288+ constexpr bool has_bias = !cutlass::has_negative_zero_v<Q>;
289+ cu::qmv<group_size, has_bias>(
276290 gpu_ptr<T>(x),
277291 gpu_ptr<Q>(w),
278292 gpu_ptr<T>(scales),
@@ -281,13 +295,13 @@ void qmv(
281295 m,
282296 n,
283297 k,
298+ l,
299+ broadcast_w,
284300 [&](auto * kernel, dim3 num_blocks, dim3 block_dims, void ** args) {
285301 encoder.add_kernel_node_raw (
286302 kernel, num_blocks, block_dims, {}, 0 , args);
287303 });
288304 });
289- });
290- });
291305 });
292306}
293307
0 commit comments