typedef void (*set_rows_kernel_t)(const char * src, char * dst);

// Generic quantized set_rows kernel template
-template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type*)>
-static __global__ void k_set_rows_quant(
-        const float * __restrict__ src0, const idx_t * __restrict__ src1, block_type * __restrict__ dst,
-        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
-        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
-        const int64_t s01, const int64_t s02, const int64_t s03,
-        const int64_t s10, const int64_t s11, const int64_t s12,
-        const int64_t s1, const int64_t s2, const int64_t s3) {
-
+template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type *)>
+static __global__ void k_set_rows_quant(const float * __restrict__ src0,
+                                        const idx_t * __restrict__ src1,
+                                        block_type * __restrict__ dst,
+                                        const int64_t ne_total,
+                                        const int64_t ne10,
+                                        const int64_t ne11,
+                                        const int64_t ne12,
+                                        const int64_t ne13,
+                                        const int64_t s01,
+                                        const int64_t s02,
+                                        const int64_t s03,
+                                        const int64_t s10,
+                                        const int64_t s11,
+                                        const int64_t s12,
+                                        const int64_t s1,
+                                        const int64_t s2,
+                                        const int64_t s3,
+                                        const uint3 ne00,
+                                        const uint3 ne01,
+                                        const uint3 ne02,
+                                        const uint3 ne11_fd,
+                                        const uint3 ne12_fd) {
    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
-    const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;

    if (i >= ne_total) {
        return;
    }

    const int64_t i_base = i * qk;
-    const int64_t i03 = i_base / (ne00 * ne01 * ne02);
-    const int64_t i02 = (i_base - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
-    const int64_t i01 = (i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
-    const int64_t i00 = i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
+    uint32_t tmp = (uint32_t) i_base;
+    uint2 div_mod;
+
+    div_mod = fast_div_modulo(tmp, ne00);
+    const int64_t i00 = div_mod.y;
+    tmp = div_mod.x;

-    const int64_t i12 = i03 % ne12;
-    const int64_t i11 = i02 % ne11;
+    div_mod = fast_div_modulo(tmp, ne01);
+    const int64_t i01 = div_mod.y;
+    tmp = div_mod.x;
+
+    div_mod = fast_div_modulo(tmp, ne02);
+    const int64_t i02 = div_mod.y;
+    const int64_t i03 = div_mod.x;
+
+    const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
+    const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
    const int64_t i10 = i01;

    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
@@ -41,6 +64,8 @@ static __global__ void k_set_rows_quant(
    quantize_func(src_block, dst_block);

    GGML_UNUSED(ne10);
+    GGML_UNUSED(ne11);
+    GGML_UNUSED(ne12);
    GGML_UNUSED(ne13);
}

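Note on the index math above: the nested divide/subtract decomposition is replaced by a chained divmod that peels dimensions from the innermost outwards (ne00, then ne01, then ne02), with the final quotient becoming i03; as used above, fast_div_modulo returns the quotient in .x and the remainder in .y. Below is a minimal host-side sketch (not part of the patch; the shape values are arbitrary) that checks the two decompositions agree:

#include <cassert>
#include <cstdint>

int main() {
    const int64_t ne00 = 7, ne01 = 5, ne02 = 3, ne03 = 2;

    for (int64_t i = 0; i < ne00*ne01*ne02*ne03; ++i) {
        // old formulation: peel dimensions from the outside in
        const int64_t o03 =  i / (ne00*ne01*ne02);
        const int64_t o02 = (i - o03*ne00*ne01*ne02) / (ne00*ne01);
        const int64_t o01 = (i - o03*ne00*ne01*ne02 - o02*ne00*ne01) / ne00;
        const int64_t o00 =  i - o03*ne00*ne01*ne02 - o02*ne00*ne01 - o01*ne00;

        // new formulation: chained divmod from the inside out, which is what the
        // kernels now compute with fast_div_modulo (multiply + shift) on device
        int64_t tmp = i;
        const int64_t n00 = tmp % ne00; tmp /= ne00;
        const int64_t n01 = tmp % ne01; tmp /= ne01;
        const int64_t n02 = tmp % ne02;
        const int64_t n03 = tmp / ne02;

        assert(n00 == o00 && n01 == o01 && n02 == o02 && n03 == o03);
    }
    return 0;
}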
@@ -71,40 +96,65 @@ static void set_rows_cuda_quant(
    const int64_t s2 = nb2;
    const int64_t s3 = nb3;

-    if (ne_total > 0) {
+    if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
+        const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
+        const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
+        const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
+        const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
+        const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
+
        k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
-            src0_d, src1_d, dst_d,
-            ne00, ne01, ne02, ne03,
-            ne10, ne11, ne12, ne13,
-            s01, s02, s03,
-            s10, s11, s12,
-            s1, s2, s3);
+            src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd,
+            ne01_fd, ne02_fd, ne11_fd, ne12_fd);
    }
}

-template <typename src_t, typename idx_t, typename dst_t>
-static __global__ void k_set_rows(
-        const src_t * __restrict__ src0, const idx_t * __restrict__ src1, dst_t * __restrict__ dst,
-        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
-        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
-        const int64_t s01, const int64_t s02, const int64_t s03,
-        const int64_t s10, const int64_t s11, const int64_t s12,
-        const int64_t s1, const int64_t s2, const int64_t s3) {
-
+template <typename src_t, typename idx_t, typename dst_t>
+static __global__ void k_set_rows(const src_t * __restrict__ src0,
+                                  const idx_t * __restrict__ src1,
+                                  dst_t * __restrict__ dst,
+                                  const int64_t ne_total,
+                                  const int64_t ne10,
+                                  const int64_t ne11,
+                                  const int64_t ne12,
+                                  const int64_t ne13,
+                                  const int64_t s01,
+                                  const int64_t s02,
+                                  const int64_t s03,
+                                  const int64_t s10,
+                                  const int64_t s11,
+                                  const int64_t s12,
+                                  const int64_t s1,
+                                  const int64_t s2,
+                                  const int64_t s3,
+                                  const uint3 ne00,
+                                  const uint3 ne01,
+                                  const uint3 ne02,
+                                  const uint3 ne11_fd,
+                                  const uint3 ne12_fd) {
    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
-    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;

    if (i >= ne_total) {
        return;
    }

-    const int64_t i03 = i / (ne00 * ne01 * ne02);
-    const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
-    const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
-    const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
+    uint32_t tmp = (uint32_t) i;
+    uint2 div_mod;
+
+    div_mod = fast_div_modulo(tmp, ne00);
+    const int64_t i00 = div_mod.y;
+    tmp = div_mod.x;

-    const int64_t i12 = i03 % ne12;
-    const int64_t i11 = i02 % ne11;
+    div_mod = fast_div_modulo(tmp, ne01);
+    const int64_t i01 = div_mod.y;
+    tmp = div_mod.x;
+
+    div_mod = fast_div_modulo(tmp, ne02);
+    const int64_t i02 = div_mod.y;
+    const int64_t i03 = div_mod.x;
+
+    const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
+    const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
    const int64_t i10 = i01;

    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
@@ -115,6 +165,8 @@ static __global__ void k_set_rows(
    dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);

    GGML_UNUSED(ne10);
+    GGML_UNUSED(ne11);
+    GGML_UNUSED(ne12);
    GGML_UNUSED(ne13);
}

@@ -144,14 +196,16 @@ static void set_rows_cuda(
    const int64_t s2 = nb2/sizeof(dst_t);
    const int64_t s3 = nb3/sizeof(dst_t);

-    if (ne_total > 0) {
-        k_set_rows<<<grid_size, block_size, 0, stream>>>(
-            src0_d, src1_d, dst_d,
-            ne00, ne01, ne02, ne03,
-            ne10, ne11, ne12, ne13,
-            s01, s02, s03,
-            s10, s11, s12,
-            s1, s2, s3);
+    if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
+        const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
+        const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
+        const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
+        const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
+        const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
+
+        k_set_rows<<<grid_size, block_size, 0, stream>>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,
+                                                         s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,
+                                                         ne11_fd, ne12_fd);
    }
}

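For context, fastdiv/fastmodulo replace runtime 32-bit division and modulo with a precomputed multiply-and-shift, which is considerably cheaper than an integer divide inside a CUDA kernel. The sketch below is illustrative only: make_fastdiv and fast_divmod are hypothetical helpers, and the exact constants chosen by init_fastdiv_values/fast_div_modulo in common.cuh may differ; it demonstrates the general magic-multiplier technique (Granlund-Montgomery round-up style) on the host, using a 128-bit product where device code would use __umulhi.

#include <cassert>
#include <cstdint>
#include <initializer_list>

// Hypothetical precomputed divisor data, analogous in spirit to the uint3 returned by init_fastdiv_values.
struct fastdiv_t {
    uint64_t m;      // magic multiplier: ceil(2^(32+L) / d)
    uint32_t shift;  // 32 + L, where L = ceil(log2(d))
    uint32_t d;      // original divisor
};

// Host-side precompute for a divisor d > 0.
static fastdiv_t make_fastdiv(uint32_t d) {
    uint32_t L = 0;
    while (L < 32 && (uint64_t(1) << L) < d) {
        L++;
    }
    const uint64_t m = (uint64_t) ((((__uint128_t) 1 << (32 + L)) + d - 1) / d);
    return { m, 32 + L, d };
}

// Quotient and remainder via one wide multiply and one shift, no hardware divide.
static void fast_divmod(uint32_t n, const fastdiv_t & f, uint32_t & q, uint32_t & r) {
    q = (uint32_t) (((__uint128_t) f.m * n) >> f.shift);
    r = n - q * f.d;
}

int main() {
    // spot-check against the hardware divide for a few divisors and dividends
    for (uint32_t d : { 1u, 3u, 7u, 64u, 4096u, 11008u }) {
        const fastdiv_t f = make_fastdiv(d);
        for (uint32_t n : { 0u, 1u, 2u, 12345u, 4294967295u }) {
            uint32_t q, r;
            fast_divmod(n, f, q, r);
            assert(q == n / d && r == n % d);
        }
    }
    return 0;
}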