44typedef  void  (*set_rows_kernel_t )(const  char  * src, char  * dst);
55
66//  Generic quantized set_rows kernel template
7- template <typename  idx_t , typename  block_type, int  qk, void  (*quantize_func)(const  float *, block_type*)>
8- static  __global__  void  k_set_rows_quant (
9-         const  float  * __restrict__  src0, const  idx_t  * __restrict__  src1, block_type * __restrict__  dst,
10-         const  int64_t  ne00, const  int64_t  ne01, const  int64_t  ne02, const  int64_t  ne03,
11-         const  int64_t  ne10, const  int64_t  ne11, const  int64_t  ne12, const  int64_t  ne13,
12-         const  int64_t  s01, const  int64_t  s02, const  int64_t  s03,
13-         const  int64_t  s10, const  int64_t  s11, const  int64_t  s12,
14-         const  int64_t  s1, const  int64_t  s2, const  int64_t  s3) {
15- 
7+ template  <typename  idx_t , typename  block_type, int  qk, void  (*quantize_func)(const  float  *, block_type *)>
8+ static  __global__  void  k_set_rows_quant (const  float  * __restrict__  src0,
9+                                         const  idx_t  * __restrict__  src1,
10+                                         block_type * __restrict__  dst,
11+                                         const  int64_t  ne_total,
12+                                         const  int64_t  ne10,
13+                                         const  int64_t  ne11,
14+                                         const  int64_t  ne12,
15+                                         const  int64_t  ne13,
16+                                         const  int64_t  s01,
17+                                         const  int64_t  s02,
18+                                         const  int64_t  s03,
19+                                         const  int64_t  s10,
20+                                         const  int64_t  s11,
21+                                         const  int64_t  s12,
22+                                         const  int64_t  s1,
23+                                         const  int64_t  s2,
24+                                         const  int64_t  s3,
25+                                         const  uint3    ne00,
26+                                         const  uint3    ne01,
27+                                         const  uint3    ne02,
28+                                         const  uint3    ne11_fd,
29+                                         const  uint3    ne12_fd) {
1630    const  int64_t  i = int64_t (blockDim .x ) * blockIdx .x  + threadIdx .x ;
17-     const  int64_t  ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
1831
1932    if  (i >= ne_total) {
2033        return ;
2134    }
2235
2336    const  int64_t  i_base = i * qk;
24-     const  int64_t  i03 = i_base / (ne00 * ne01 * ne02);
25-     const  int64_t  i02 = (i_base - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
26-     const  int64_t  i01 = (i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
27-     const  int64_t  i00 = i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
37+     uint32_t       tmp    = (uint32_t ) i_base;
38+     uint2          div_mod;
39+ 
40+     div_mod           = fast_div_modulo (tmp, ne00);
41+     const  int64_t  i00 = div_mod.y ;
42+     tmp               = div_mod.x ;
2843
29-     const  int64_t  i12 = i03 % ne12;
30-     const  int64_t  i11 = i02 % ne11;
44+     div_mod           = fast_div_modulo (tmp, ne01);
45+     const  int64_t  i01 = div_mod.y ;
46+     tmp               = div_mod.x ;
47+ 
48+     div_mod           = fast_div_modulo (tmp, ne02);
49+     const  int64_t  i02 = div_mod.y ;
50+     const  int64_t  i03 = div_mod.x ;
51+ 
52+     const  int64_t  i12 = fastmodulo ((uint32_t ) i03, ne12_fd);
53+     const  int64_t  i11 = fastmodulo ((uint32_t ) i02, ne11_fd);
3154    const  int64_t  i10 = i01;
3255
3356    const  int64_t  dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
@@ -41,6 +64,8 @@ static __global__ void k_set_rows_quant(
4164    quantize_func (src_block, dst_block);
4265
4366    GGML_UNUSED (ne10);
67+     GGML_UNUSED (ne11);
68+     GGML_UNUSED (ne12);
4469    GGML_UNUSED (ne13);
4570}
4671
@@ -71,40 +96,65 @@ static void set_rows_cuda_quant(
7196    const  int64_t  s2  = nb2;
7297    const  int64_t  s3  = nb3;
7398
74-     if  (ne_total > 0 ) {
99+     if  (ne_total > 0  && ne00 > 0  && ne01 > 0  && ne02 > 0  && ne11 > 0  && ne12 > 0 ) {
100+         const  uint3  ne00_fd = init_fastdiv_values ((uint32_t ) ne00);
101+         const  uint3  ne01_fd = init_fastdiv_values ((uint32_t ) ne01);
102+         const  uint3  ne02_fd = init_fastdiv_values ((uint32_t ) ne02);
103+         const  uint3  ne11_fd = init_fastdiv_values ((uint32_t ) ne11);
104+         const  uint3  ne12_fd = init_fastdiv_values ((uint32_t ) ne12);
105+ 
75106        k_set_rows_quant<idx_t , block_type, qk, quantize_func><<<grid_size, block_size, 0 , stream>>> (
76-             src0_d, src1_d, dst_d,
77-             ne00, ne01, ne02, ne03,
78-             ne10, ne11, ne12, ne13,
79-             s01, s02, s03,
80-             s10, s11, s12,
81-             s1, s2, s3);
107+             src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd,
108+             ne01_fd, ne02_fd, ne11_fd, ne12_fd);
82109    }
83110}
84111
85- template <typename  src_t , typename  idx_t , typename  dst_t >
86- static  __global__  void  k_set_rows (
87-         const  src_t  * __restrict__  src0, const  idx_t  * __restrict__  src1, dst_t  * __restrict__  dst,
88-         const  int64_t  ne00, const  int64_t  ne01, const  int64_t  ne02, const  int64_t  ne03,
89-         const  int64_t  ne10, const  int64_t  ne11, const  int64_t  ne12, const  int64_t  ne13,
90-         const  int64_t  s01, const  int64_t  s02, const  int64_t  s03,
91-         const  int64_t  s10, const  int64_t  s11, const  int64_t  s12,
92-         const  int64_t  s1, const  int64_t  s2, const  int64_t  s3) {
93- 
112+ template  <typename  src_t , typename  idx_t , typename  dst_t >
113+ static  __global__  void  k_set_rows (const  src_t  * __restrict__  src0,
114+                                   const  idx_t  * __restrict__  src1,
115+                                   dst_t  * __restrict__  dst,
116+                                   const  int64_t  ne_total,
117+                                   const  int64_t  ne10,
118+                                   const  int64_t  ne11,
119+                                   const  int64_t  ne12,
120+                                   const  int64_t  ne13,
121+                                   const  int64_t  s01,
122+                                   const  int64_t  s02,
123+                                   const  int64_t  s03,
124+                                   const  int64_t  s10,
125+                                   const  int64_t  s11,
126+                                   const  int64_t  s12,
127+                                   const  int64_t  s1,
128+                                   const  int64_t  s2,
129+                                   const  int64_t  s3,
130+                                   const  uint3    ne00,
131+                                   const  uint3    ne01,
132+                                   const  uint3    ne02,
133+                                   const  uint3    ne11_fd,
134+                                   const  uint3    ne12_fd) {
94135    const  int64_t  i = int64_t (blockDim .x ) * blockIdx .x  + threadIdx .x ;
95-     const  int64_t  ne_total = ne00 * ne01 * ne02 * ne03;
96136
97137    if  (i >= ne_total) {
98138        return ;
99139    }
100140
101-     const  int64_t  i03 = i / (ne00 * ne01 * ne02);
102-     const  int64_t  i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
103-     const  int64_t  i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
104-     const  int64_t  i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
141+     uint32_t  tmp = (uint32_t ) i;
142+     uint2     div_mod;
143+ 
144+     div_mod           = fast_div_modulo (tmp, ne00);
145+     const  int64_t  i00 = div_mod.y ;
146+     tmp               = div_mod.x ;
105147
106-     const  int64_t  i12 = i03 % ne12;
107-     const  int64_t  i11 = i02 % ne11;
148+     div_mod           = fast_div_modulo (tmp, ne01);
149+     const  int64_t  i01 = div_mod.y ;
150+     tmp               = div_mod.x ;
151+ 
152+     div_mod           = fast_div_modulo (tmp, ne02);
153+     const  int64_t  i02 = div_mod.y ;
154+     const  int64_t  i03 = div_mod.x ;
155+ 
156+     const  int64_t  i12 = fastmodulo ((uint32_t ) i03, ne12_fd);
157+     const  int64_t  i11 = fastmodulo ((uint32_t ) i02, ne11_fd);
108158    const  int64_t  i10 = i01;
109159
110160    const  int64_t  dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
@@ -115,6 +165,8 @@ static __global__ void k_set_rows(
115165    dst_row_ptr[i00] = ggml_cuda_cast<dst_t >(src0_row[i00]);
116166
117167    GGML_UNUSED (ne10);
168+     GGML_UNUSED (ne11);
169+     GGML_UNUSED (ne12);
118170    GGML_UNUSED (ne13);
119171}
120172
@@ -144,14 +196,16 @@ static void set_rows_cuda(
144196    const  int64_t  s2  = nb2/sizeof (dst_t );
145197    const  int64_t  s3  = nb3/sizeof (dst_t );
146198
147-     if  (ne_total > 0 ) {
148-         k_set_rows<<<grid_size, block_size, 0 , stream>>> (
149-             src0_d, src1_d, dst_d,
150-             ne00, ne01, ne02, ne03,
151-             ne10, ne11, ne12, ne13,
152-             s01, s02, s03,
153-             s10, s11, s12,
154-             s1, s2, s3);
199+     if  (ne_total > 0  && ne00 > 0  && ne01 > 0  && ne02 > 0  && ne11 > 0  && ne12 > 0 ) {
200+         const  uint3  ne00_fd = init_fastdiv_values ((uint32_t ) ne00);
201+         const  uint3  ne01_fd = init_fastdiv_values ((uint32_t ) ne01);
202+         const  uint3  ne02_fd = init_fastdiv_values ((uint32_t ) ne02);
203+         const  uint3  ne11_fd = init_fastdiv_values ((uint32_t ) ne11);
204+         const  uint3  ne12_fd = init_fastdiv_values ((uint32_t ) ne12);
205+ 
206+         k_set_rows<<<grid_size, block_size, 0 , stream>>> (src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,
207+                                                          s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,
208+                                                          ne11_fd, ne12_fd);
155209    }
156210}
157211
0 commit comments