@@ -20,6 +20,36 @@ namespace cg = cooperative_groups;
// Fused dequantize-multiply-accumulate: out = fma(x, dequant(w), out).
//
// Preconditions: x and w must be aligned for an N-wide vector load
// (cutlass::AlignedArray), and out must point to thread-private storage
// (registers) holding the running accumulator.
template <int N, typename T, typename Q>
__device__ __forceinline__ void
dequant_fma(const T* x, const Q* w, T scale, T bias, T* out) {
  // Pull the activations and quantized weights into registers with a
  // single vectorized load each.
  auto x_reg = *(reinterpret_cast<const cutlass::AlignedArray<T, N>*>(x));
  auto w_q = *(reinterpret_cast<const cutlass::AlignedArray<Q, N>*>(w));
  // Accumulator is assumed to already live in registers.
  auto* acc = reinterpret_cast<cutlass::Array<T, N>*>(out);

  // Widen Q -> T, then apply the affine dequantization (scale/bias).
  cutlass::NumericArrayConverter<T, Q, N> widen;
  cutlass::Array<T, N> w_dq = widen(w_q);
  if constexpr (cuda::std::is_same_v<T, float>) {
    // Element-wise path when T is float.
#pragma unroll
    for (int i = 0; i < N; ++i) {
      w_dq[i] = w_dq[i] * scale + bias;
    }
  } else {
    // Non-float element types go through cutlass::Array's vectorized
    // scalar-broadcast operators.
    w_dq = w_dq * scale + bias;
  }

  // Accumulate: acc += x * w_dq.
  *acc = cutlass::fma(x_reg, w_dq, *acc);
}
45+
46+ // Specialization for doing float32 accumulations on narrow types.
47+ template <
48+ int N,
49+ typename T,
50+ typename Q,
51+ typename = cuda::std::enable_if_t <!cuda::std::is_same_v<T, float >>>
52+ __device__ __forceinline__ void
2353dequant_fma (const T* x, const Q* w, T scale, T bias, float * out) {
2454 // Read x/w into registers.
2555 auto x_vec = *(reinterpret_cast <const cutlass::AlignedArray<T, N>*>(x));
@@ -42,24 +72,6 @@ dequant_fma(const T* x, const Q* w, T scale, T bias, float* out) {
4272 *out_vec = cutlass::fma (x_f, w_f, *out_vec);
4373}
4474
45- // Specialized for float which does not need promotions.
46- template <int N, typename Q>
47- __device__ __forceinline__ void
48- dequant_fma (const float * x, const Q* w, float scale, float bias, float * out) {
49- auto x_vec = *(reinterpret_cast <const cutlass::AlignedArray<float , N>*>(x));
50- auto w_vec = *(reinterpret_cast <const cutlass::AlignedArray<Q, N>*>(w));
51- auto * out_vec = reinterpret_cast <cutlass::Array<float , N>*>(out);
52-
53- cutlass::NumericArrayConverter<float , Q, N> converter;
54- cutlass::Array<float , N> w_dq = converter (w_vec);
55- #pragma unroll
56- for (int i = 0 ; i < N; ++i) {
57- w_dq[i] = w_dq[i] * scale + bias;
58- }
59-
60- *out_vec = cutlass::fma (x_vec, w_dq, *out_vec);
61- }
62-
6375template <
6476 int rows_per_block,
6577 int elems_per_thread,
@@ -91,7 +103,8 @@ __global__ void qmv_kernel(
91103
92104 // For sub-byte Q, pointer moves by 8bits for each advance, e.g. w += 1 would
93105 // move past 2 elements for 4-bit Q.
94- constexpr int w_step = 8 / cuda::std::min (8 , cute::sizeof_bits_v<Q>);
106+ constexpr int bits = cute::sizeof_bits_v<Q>;
107+ constexpr int w_step = 8 / cuda::std::min (8 , bits);
95108
96109 // How many groups (and scales/biases) in a row.
97110 int groups_per_row = k / group_size;
@@ -104,7 +117,7 @@ __global__ void qmv_kernel(
104117 }
105118
106119 // Accumulations of current row.
107- float sums[elems_per_thread] = {};
120+ cuda::std:: conditional_t <(bits >= 8 ), float , T> sums[elems_per_thread] = {};
108121
109122 auto dequant_fma_tile = [&](int idx) {
110123 T scale = scales[idx / group_size];
@@ -157,7 +170,8 @@ void qmv(
157170 int k,
158171 F&& launch_kernel) {
159172 constexpr int rows_per_block = 8 ;
160- constexpr int elems_per_thread = 8 ;
173+ constexpr int elems_per_thread =
174+ (cute::sizeof_bits_v<T> <= 16 && cute::sizeof_bits_v<Q> <= 4 ) ? 16 : 8 ;
161175
162176 dim3 num_blocks{uint32_t (cuda::ceil_div (n, rows_per_block)), uint32_t (m)};
163177 dim3 block_dims{WARP_SIZE, rows_per_block};
0 commit comments