@@ -1,30 +1,40 @@
 #include "quantize.cuh"
 #include <cstdint>
 
-static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) {
-    const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void quantize_q8_1(
+        const float * __restrict__ x, void * __restrict__ vy,
+        const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+        const int64_t ne0, const int ne1, const int ne2) {
+    const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (ix0 >= kx0_padded) {
+    if (i0 >= ne0) {
         return;
     }
 
-    const int64_t ix1 = blockIdx.y;
+    const int64_t i1 = blockIdx.y;
+    const int64_t i2 = blockIdx.z % ne2;
+    const int64_t i3 = blockIdx.z / ne2;
+
+    const int64_t & i00 = i0;
+    const int64_t & i01 = i1;
+    const int64_t & i02 = i2;
+    const int64_t & i03 = i3;
 
-    const int64_t i_padded = ix1*kx0_padded + ix0;
+    const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0;
 
     block_q8_1 * y = (block_q8_1 *) vy;
 
-    const int64_t ib = i_padded / QK8_1; // block index
-    const int64_t iqs = i_padded % QK8_1; // quant index
+    const int64_t ib = i_cont / QK8_1; // block index
+    const int64_t iqs = i_cont % QK8_1; // quant index
 
-    const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f;
+    const float xi = i0 < ne00 ? x[i03*s03 + i02*s02 + i01*s01 + i00] : 0.0f;
     float amax = fabsf(xi);
     float sum = xi;
 
     amax = warp_reduce_max(amax);
-    sum = warp_reduce_sum(sum);
+    sum  = warp_reduce_sum(sum);
 
-    const float d = amax / 127;
+    const float  d = amax / 127;
     const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
 
     y[ib].qs[iqs] = q;
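
Note on the indexing change above: the q8_1 destination stays contiguous, so ib and iqs are derived from i_cont alone, while only the source reads go through the new strides s01/s02/s03; padding columns (i0 >= ne00) are quantized as zeros. A minimal host-side sketch of the same arithmetic, assuming QK8_1 == 32 as for ggml's q8_1 blocks, with made-up shapes:

// Host-side sketch of the kernel's index math (hypothetical shapes, not from this diff).
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t QK8_1 = 32;                   // assumed block size of q8_1
    const int64_t ne00 = 100;                   // logical source row length
    const int64_t ne0 = 128, ne1 = 4, ne2 = 2;  // padded row length and higher dims
    const int64_t i0 = 120, i1 = 3, i2 = 1, i3 = 0; // one thread's coordinates

    // contiguous destination index, exactly as computed in the kernel
    const int64_t i_cont = ((i3*ne2 + i2)*ne1 + i1)*ne0 + i0;
    std::printf("block %lld, quant %lld, padding: %s\n",
        (long long)(i_cont / QK8_1), (long long)(i_cont % QK8_1),
        i0 < ne00 ? "no" : "yes (stored as 0)");
    return 0;
}
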
@@ -124,43 +134,45 @@ static __global__ void quantize_mmq_q8_1( |
 }
 
 void quantize_row_q8_1_cuda(
-        const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
-        const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+        const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+        const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
 
-    GGML_ASSERT(kx0_padded % QK8_1 == 0);
+    GGML_ASSERT(ne0 % QK8_1 == 0);
 
-    const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
-    const dim3 num_blocks(block_num_x, kx1*channels, 1);
+    const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
-    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx0_padded);
-
-    GGML_UNUSED(type_x);
+    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
+    GGML_UNUSED(type_src0);
 }
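
For context, a hypothetical call-site sketch (not part of this diff) showing how the new stride arguments could be derived: ggml tensors carry byte strides in nb[], so for F32 data the element strides are nb[i]/sizeof(float), and ne0 is assumed to be the row length already padded to a multiple of QK8_1 by the caller:

// Hypothetical helper, names launch_q8_1_quantize and ne0_padded are illustrative.
#include "ggml.h"
#include "quantize.cuh"

static void launch_q8_1_quantize(const ggml_tensor * src1, const ggml_tensor * src0,
                                 void * vy, const int64_t ne0_padded, cudaStream_t stream) {
    // byte strides -> element strides, assuming src1 holds GGML_TYPE_F32 data
    const int64_t s01 = src1->nb[1] / sizeof(float);
    const int64_t s02 = src1->nb[2] / sizeof(float);
    const int64_t s03 = src1->nb[3] / sizeof(float);

    quantize_row_q8_1_cuda((const float *) src1->data, vy, src0->type,
        src1->ne[0], s01, s02, s03,
        ne0_padded, src1->ne[1], src1->ne[2], src1->ne[3], stream);
}
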
 
 void quantize_mmq_q8_1_cuda(
-        const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
-        const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+        const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+        const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
 
-    GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
+    GGML_ASSERT(ne0 % (4*QK8_1) == 0);
 
-    const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
-    const dim3 num_blocks(block_num_x, kx1, channels);
+    const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
+    const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
-    switch (mmq_get_q8_1_ds_layout(type_x)) {
+    switch (mmq_get_q8_1_ds_layout(type_src0)) {
         case MMQ_Q8_1_DS_LAYOUT_D4:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
             break;
         case MMQ_Q8_1_DS_LAYOUT_DS4:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
             break;
         case MMQ_Q8_1_DS_LAYOUT_D2S6:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
             break;
         default:
             GGML_ABORT("fatal error");
             break;
     }
+    GGML_UNUSED(s01);
+    GGML_UNUSED(s02);
+    GGML_UNUSED(s03);
 }
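
Note that unlike quantize_row_q8_1_cuda above, the MMQ launcher still passes only (ne00, ne1, ne0) to its kernel, so the new stride parameters are accepted but unused (hence the GGML_UNUSED markers) and the source presumably still has to be contiguous on this path. A hypothetical guard, placed inside the launcher, that would make that assumption explicit:

// Sketch of the contiguity this path appears to rely on (hypothetical, not in
// the diff): for tightly packed F32 data the element strides collapse to the
// products below, which is why the kernel can be launched without them.
GGML_ASSERT(s01 == ne00);
GGML_ASSERT(s02 == ne00*ne1);
GGML_ASSERT(s03 == ne00*ne1*ne2);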