22#include " dequantize.cuh"
33#include " convert.cuh"
44
5- #define MAX_GRIDDIM_Y 65535
6-
75template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t >
86static __global__ void k_get_rows (
97 const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
108 const int64_t ne00, /* const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
11- /* const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /* const int64_t ne13,*/
9+ /* const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /* const int64_t ne13,*/
1210 /* const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
1311 /* const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
1412 const size_t s10, const size_t s11, const size_t s12/* , const size_t s13*/ ) {
1513
16- for (int64_t i00 = 2 *(blockIdx .y *blockDim .x + threadIdx .x ); i00 < ne00; i00 += gridDim .y *blockDim .x ) {
17- // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
18- const int i10 = blockIdx .x ;
19- const int i11 = blockIdx .z / ne12;
20- const int i12 = blockIdx .z % ne12;
14+ for (int64_t z = blockIdx .z ; z < ne11*ne12; z += gridDim .z ) {
15+ for (int64_t i00 = 2 *(blockIdx .y *blockDim .x + threadIdx .x ); i00 < ne00; i00 += gridDim .y *blockDim .x ) {
16+ // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
17+ const int i10 = blockIdx .x ;
18+ const int i11 = z / ne12; // TODO fastdiv
19+ const int i12 = z % ne12;
2120
22- const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
21+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
2322
24- dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
25- const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
23+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
24+ const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
2625
27- const int ib = i00/qk; // block index
28- const int iqs = (i00%qk)/qr; // quant index
29- const int iybs = i00 - i00%qk; // dst block start index
30- const int y_offset = qr == 1 ? 1 : qk/2 ;
26+ const int ib = i00/qk; // block index
27+ const int iqs = (i00%qk)/qr; // quant index
28+ const int iybs = i00 - i00%qk; // dst block start index
29+ const int y_offset = qr == 1 ? 1 : qk/2 ;
3130
32- // dequantize
33- float2 v;
34- dequantize_kernel (src0_row, ib, iqs, v);
31+ // dequantize
32+ float2 v;
33+ dequantize_kernel (src0_row, ib, iqs, v);
3534
36- dst_row[iybs + iqs + 0 ] = ggml_cuda_cast<dst_t >(v.x );
37- dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t >(v.y );
35+ dst_row[iybs + iqs + 0 ] = ggml_cuda_cast<dst_t >(v.x );
36+ dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t >(v.y );
37+ }
3838 }
3939}
4040
4141template <typename src0_t , typename dst_t >
4242static __global__ void k_get_rows_float (
4343 const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
4444 const int64_t ne00, /* const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
45- /* const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /* const int64_t ne13,*/
45+ /* const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /* const int64_t ne13,*/
4646 /* const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
4747 /* const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
4848 const size_t s10, const size_t s11, const size_t s12/* , const size_t s13*/ ) {
4949
50- for (int64_t i00 = blockIdx .y *blockDim .x + threadIdx .x ; i00 < ne00; i00 += gridDim .y *blockDim .x ) {
51- // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
52- const int i10 = blockIdx .x ;
53- const int i11 = blockIdx .z / ne12;
54- const int i12 = blockIdx .z % ne12;
50+ for (int64_t z = blockIdx .z ; z < ne11*ne12; z += gridDim .z ) {
51+ for (int64_t i00 = blockIdx .y *blockDim .x + threadIdx .x ; i00 < ne00; i00 += gridDim .y *blockDim .x ) {
52+ // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
53+ const int i10 = blockIdx .x ;
54+ const int i11 = z / ne12; // TODO fastdiv
55+ const int i12 = z % ne12;
5556
56- if (i00 >= ne00) {
57- return ;
58- }
57+ if (i00 >= ne00) {
58+ return ;
59+ }
5960
60- const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
61+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
6162
62- dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
63- const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
63+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
64+ const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
6465
65- dst_row[i00] = ggml_cuda_cast<dst_t >(src0_row[i00]);
66+ dst_row[i00] = ggml_cuda_cast<dst_t >(src0_row[i00]);
67+ }
6668 }
6769}
6870
@@ -98,7 +100,7 @@ static void get_rows_cuda_q(
98100 cudaStream_t stream) {
99101 const dim3 block_dims (CUDA_GET_ROWS_BLOCK_SIZE, 1 , 1 );
100102 const int block_num_y = (ne00 + 2 *CUDA_GET_ROWS_BLOCK_SIZE - 1 ) / (2 *CUDA_GET_ROWS_BLOCK_SIZE);
101- const dim3 block_nums (ne10, MIN (block_num_y, MAX_GRIDDIM_Y ), ne11*ne12);
103+ const dim3 block_nums (ne10, MIN (block_num_y, UINT16_MAX ), MIN ( ne11*ne12, UINT16_MAX) );
102104
103105 // strides in elements
104106 // const size_t s0 = nb0 / sizeof(dst_t);
@@ -116,7 +118,7 @@ static void get_rows_cuda_q(
116118 k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0 , stream>>> (
117119 src0_d, src1_d, dst_d,
118120 ne00, /* ne01, ne02, ne03,*/
119- /* ne10, ne11,*/ ne12, /* ne13,*/
121+ /* ne10,*/ ne11, ne12, /* ne13,*/
120122 /* s0,*/ s1, s2, s3,
121123 /* nb00,*/ nb01, nb02, nb03,
122124 s10, s11, s12/* , s13*/ );
@@ -131,7 +133,7 @@ static void get_rows_cuda_float(
131133 cudaStream_t stream) {
132134 const dim3 block_dims (CUDA_GET_ROWS_BLOCK_SIZE, 1 , 1 );
133135 const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1 ) / CUDA_GET_ROWS_BLOCK_SIZE;
134- const dim3 block_nums (ne10, MIN (block_num_y, MAX_GRIDDIM_Y ), ne11*ne12);
136+ const dim3 block_nums (ne10, MIN (block_num_y, UINT16_MAX ), MIN ( ne11*ne12, UINT16_MAX) );
135137
136138 // strides in elements
137139 // const size_t s0 = nb0 / sizeof(dst_t);
@@ -147,7 +149,7 @@ static void get_rows_cuda_float(
147149 k_get_rows_float<<<block_nums, block_dims, 0 , stream>>> (
148150 src0_d, src1_d, dst_d,
149151 ne00, /* ne01, ne02, ne03,*/
150- /* ne10, ne11,*/ ne12, /* ne13,*/
152+ /* ne10,*/ ne11, ne12, /* ne13,*/
151153 /* s0,*/ s1, s2, s3,
152154 /* nb00,*/ nb01, nb02, nb03,
153155 s10, s11, s12/* , s13*/ );
0 commit comments