@@ -23,28 +23,44 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {
2323 return a / b;
2424}
2525
26-
27-
28- template <float (*bin_op)(const float , const float ), typename src0_t , typename src1_t , typename dst_t , typename ... src1_ptrs>
29- static __global__ void k_bin_bcast (const src0_t * src0, const src1_t * src1, dst_t * dst,
30- const int ne0, const int ne1, const int ne2, const int ne3,
31- const int ne10, const int ne11, const int ne12, const int ne13,
32- /* int s0, */ const int s1, const int s2, const int s3,
33- /* int s00,*/ const int s01, const int s02, const int s03,
34- /* int s10,*/ const int s11, const int s12, const int s13,
35- src1_ptrs... src1s) {
36- const int i0s = blockDim .x *blockIdx .x + threadIdx .x ;
37- const int i1 = (blockDim .y *blockIdx .y + threadIdx .y );
38- const int i2 = (blockDim .z *blockIdx .z + threadIdx .z ) / ne3;
39- const int i3 = (blockDim .z *blockIdx .z + threadIdx .z ) % ne3;
40-
41- if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
26+ template <float (*bin_op)(const float , const float ),
27+ typename src0_t ,
28+ typename src1_t ,
29+ typename dst_t ,
30+ typename ... src1_ptrs>
31+ static __global__ void k_bin_bcast (const src0_t * src0,
32+ const src1_t * src1,
33+ dst_t * dst,
34+ const int ne0,
35+ const int ne1,
36+ const int ne2,
37+ const uint3 ne3_fastdiv,
38+ const uint3 ne10_fastdiv,
39+ const uint3 ne11_fastdiv,
40+ const uint3 ne12_fastdiv,
41+ const uint3 ne13_fastdiv,
42+ /* int s0, */ const int s1,
43+ const int s2,
44+ const int s3,
45+ /* int s00,*/ const int s01,
46+ const int s02,
47+ const int s03,
48+ /* int s10,*/ const int s11,
49+ const int s12,
50+ const int s13,
51+ src1_ptrs... src1s) {
52+ const uint32_t i0s = blockDim .x * blockIdx .x + threadIdx .x ;
53+ const uint32_t i1 = (blockDim .y * blockIdx .y + threadIdx .y );
54+ const uint32_t i2 = fastdiv ((blockDim .z * blockIdx .z + threadIdx .z ), ne3_fastdiv);
55+ const uint32_t i3 = (blockDim .z * blockIdx .z + threadIdx .z ) - (i2 * ne3_fastdiv.z );
56+
57+ if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3_fastdiv.z ) {
4258 return ;
4359 }
4460
45- const int i11 = i1 % ne11 ;
46- const int i12 = i2 % ne12 ;
47- const int i13 = i3 % ne13 ;
61+ const uint32_t i11 = fastmodulo (i1, ne11_fastdiv) ;
62+ const uint32_t i12 = fastmodulo (i2, ne12_fastdiv) ;
63+ const uint32_t i13 = fastmodulo (i3, ne13_fastdiv) ;
4864
4965 const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
5066 const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@@ -53,8 +69,8 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
5369 const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr ;
5470 dst_t * dst_row = dst + i_dst;
5571
56- for (int i0 = i0s; i0 < ne0; i0 += blockDim .x * gridDim .x ) {
57- const int i10 = i0 % ne10 ;
72+ for (int i0 = i0s; i0 < ne0; i0 += blockDim .x * gridDim .x ) {
73+ const uint32_t i10 = fastmodulo (i0, ne10_fastdiv) ;
5874
5975 float result = src0_row ? (float ) src0_row[i0] : 0 .0f ;
6076 if constexpr (sizeof ...(src1_ptrs) > 0 ) {
@@ -67,28 +83,48 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
6783 }
6884}
6985
70- template <float (*bin_op)(const float , const float ), typename src0_t , typename src1_t , typename dst_t , typename ... src1_ptrs>
71- static __global__ void k_bin_bcast_unravel (const src0_t * src0, const src1_t * src1, dst_t * dst,
72- const int ne0, const int ne1, const int ne2,const int ne3,
73- const int ne10, const int ne11, const int ne12, const int ne13,
74- /* int s0, */ const int s1, const int s2, const int s3,
75- /* int s00,*/ const int s01, const int s02, const int s03,
76- /* int s10,*/ const int s11, const int s12, const int s13,
77- src1_ptrs ... src1s) {
86+ template <float (*bin_op)(const float , const float ),
87+ typename src0_t ,
88+ typename src1_t ,
89+ typename dst_t ,
90+ typename ... src1_ptrs>
91+ static __global__ void k_bin_bcast_unravel (const src0_t * src0,
92+ const src1_t * src1,
93+ dst_t * dst,
94+ const uint3 ne0_fastdiv,
95+ const uint3 ne1_fastdiv,
96+ const uint3 ne2_fastdiv,
97+ const uint32_t ne3,
98+ const uint3 ne012_fastdiv,
99+ const uint3 ne01_fastdiv,
100+ const uint3 ne10_fastdiv,
101+ const uint3 ne11_fastdiv,
102+ const uint3 ne12_fastdiv,
103+ const uint3 ne13_fastdiv,
104+ /* int s0, */ const int s1,
105+ const int s2,
106+ const int s3,
107+ /* int s00,*/ const int s01,
108+ const int s02,
109+ const int s03,
110+ /* int s10,*/ const int s11,
111+ const int s12,
112+ const int s13,
113+ src1_ptrs... src1s) {
78114 const int i = blockDim .x *blockIdx .x + threadIdx .x ;
79115
80- const int i3 = i/(ne2*ne1*ne0 );
81- const int i2 = (i/(ne1*ne0)) % ne2 ;
82- const int i1 = (i/ne0) % ne1 ;
83- const int i0 = i % ne0 ;
116+ const uint32_t i3 = fastdiv (i, ne012_fastdiv );
117+ const uint32_t i2 = fastmodulo ( fastdiv (i, ne01_fastdiv), ne2_fastdiv) ;
118+ const uint32_t i1 = fastmodulo ( fastdiv (i, ne0_fastdiv), ne1_fastdiv) ;
119+ const uint32_t i0 = fastmodulo (i, ne0_fastdiv) ;
84120
85- if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
121+ if (i0 >= ne0_fastdiv. z || i1 >= ne1_fastdiv. z || i2 >= ne2_fastdiv. z || i3 >= ne3) {
86122 return ;
87123 }
88124
89- const int i11 = i1 % ne11 ;
90- const int i12 = i2 % ne12 ;
91- const int i13 = i3 % ne13 ;
125+ const int i11 = fastmodulo (i1, ne11_fastdiv) ;
126+ const int i12 = fastmodulo (i2, ne12_fastdiv) ;
127+ const int i13 = fastmodulo (i3, ne13_fastdiv) ;
92128
93129 const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
94130 const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@@ -97,7 +133,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t *
97133 const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr ;
98134 dst_t * dst_row = dst + i_dst;
99135
100- const int i10 = i0 % ne10 ;
136+ const int i10 = fastmodulo (i0, ne10_fastdiv) ;
101137
102138 float result = src0_row ? (float ) src0_row[i0] : 0 .0f ;
103139 if constexpr (sizeof ...(src1_ptrs) > 0 ) {
@@ -170,11 +206,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
170206 // int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
171207 // int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
172208
173- int64_t ne10 = cne1[0 ];
174- int64_t ne11 = cne1[1 ];
175- int64_t ne12 = cne1[2 ];
176- int64_t ne13 = cne1[3 ];
177-
178209 size_t nb0 = cnb[0 ];
179210 size_t nb1 = cnb[1 ];
180211 size_t nb2 = cnb[2 ];
@@ -233,48 +264,53 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
233264 block_dims.y = std::min<unsigned int >(ne1, block_size / block_dims.x );
234265 block_dims.z = std::min (std::min<unsigned int >(ne2 * ne3, block_size / block_dims.x / block_dims.y ), 64U );
235266
236- dim3 block_nums ((hne0 + block_dims.x - 1 ) / block_dims.x ,
237- (ne1 + block_dims.y - 1 ) / block_dims.y ,
267+ dim3 block_nums ((hne0 + block_dims.x - 1 ) / block_dims.x , (ne1 + block_dims.y - 1 ) / block_dims.y ,
238268 (ne2 * ne3 + block_dims.z - 1 ) / block_dims.z );
239269
270+
271+ const uint3 ne10_fastdiv = init_fastdiv_values ((uint32_t ) cne1[0 ]);
272+ const uint3 ne11_fastdiv = init_fastdiv_values ((uint32_t ) cne1[1 ]);
273+ const uint3 ne12_fastdiv = init_fastdiv_values ((uint32_t ) cne1[2 ]);
274+ const uint3 ne13_fastdiv = init_fastdiv_values ((uint32_t ) cne1[3 ]);
275+
240276 if (block_nums.z > 65535 ) {
241- int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1 ) / block_size;
277+ int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1 ) / block_size;
278+ const uint3 ne0_fastdiv = init_fastdiv_values ((uint32_t ) ne0);
279+ const uint3 ne1_fastdiv = init_fastdiv_values ((uint32_t ) ne1);
280+ const uint3 ne2_fastdiv = init_fastdiv_values ((uint32_t ) ne2);
281+ const uint3 ne012_fastdiv = init_fastdiv_values ((uint32_t ) (ne0 * ne1 * ne2));
282+ const uint3 ne01_fastdiv = init_fastdiv_values ((uint32_t ) (ne0 * ne1));
242283 if constexpr (sizeof ...(I) > 0 ) {
243- k_bin_bcast_unravel<bin_op, src0_t , src1_t , dst_t >
244- <<<block_num, block_size, 0 , stream>>> (src0_dd, src1_dd, dst_dd,
245- ne0, ne1, ne2, ne3,
246- ne10, ne11, ne12, ne13,
247- /* s0, */ s1, s2, s3,
248- /* s00,*/ s01, s02, s03,
249- /* s10,*/ s11, s12,s13,
250- (const src1_t *) dst->src [I + 1 ]->data ...);
284+ k_bin_bcast_unravel<bin_op, src0_t , src1_t , dst_t ><<<block_num, block_size, 0 , stream>>> (
285+ src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, ne012_fastdiv, ne01_fastdiv,
286+ ne10_fastdiv, ne11_fastdiv, ne12_fastdiv, ne13_fastdiv,
287+ /* s0, */ s1, s2, s3,
288+ /* s00,*/ s01, s02, s03,
289+ /* s10,*/ s11, s12, s13, (const src1_t *) dst->src [I + 1 ]->data ...);
251290 } else {
252- k_bin_bcast_unravel<bin_op, src0_t , src1_t , dst_t >
253- <<<block_num, block_size, 0 , stream>>> (src0_dd, src1_dd, dst_dd,
254- ne0, ne1, ne2, ne3,
255- ne10, ne11, ne12, ne13,
256- /* s0, */ s1, s2, s3,
257- /* s00,*/ s01, s02, s03,
258- /* s10,*/ s11, s12,s13);
291+ k_bin_bcast_unravel<bin_op, src0_t , src1_t , dst_t ><<<block_num, block_size, 0 , stream>>> (
292+ src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, ne012_fastdiv, ne01_fastdiv,
293+ ne10_fastdiv, ne11_fastdiv, ne12_fastdiv, ne13_fastdiv,
294+ /* s0, */ s1, s2, s3,
295+ /* s00,*/ s01, s02, s03,
296+ /* s10,*/ s11, s12, s13);
259297 }
260298 } else {
299+ const uint3 ne3_fastdiv = init_fastdiv_values ((uint32_t ) ne3);
261300 if constexpr (sizeof ...(I) > 0 ) {
262- k_bin_bcast<bin_op, src0_t , src1_t , dst_t >
263- <<<block_nums, block_dims, 0 , stream>>> (src0_dd, src1_dd, dst_dd,
264- ne0, ne1, ne2, ne3,
265- ne10, ne11, ne12, ne13,
266- /* s0, */ s1, s2, s3,
267- /* s00,*/ s01, s02, s03,
268- /* s10,*/ s11, s12,s13,
269- (const src1_t *) dst->src [I + 1 ]->data ...);
301+ k_bin_bcast<bin_op, src0_t , src1_t , dst_t ><<<block_nums, block_dims, 0 , stream>>> (
302+ src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10_fastdiv, ne11_fastdiv, ne12_fastdiv,
303+ ne13_fastdiv,
304+ /* s0, */ s1, s2, s3,
305+ /* s00,*/ s01, s02, s03,
306+ /* s10,*/ s11, s12, s13, (const src1_t *) dst->src [I + 1 ]->data ...);
270307 } else {
271308 k_bin_bcast<bin_op, src0_t , src1_t , dst_t >
272- <<<block_nums, block_dims, 0 , stream>>> (src0_dd, src1_dd, dst_dd,
273- ne0, ne1, ne2, ne3,
274- ne10, ne11, ne12, ne13,
275- /* s0, */ s1, s2, s3,
276- /* s00,*/ s01, s02, s03,
277- /* s10,*/ s11, s12,s13);
309+ <<<block_nums, block_dims, 0 , stream>>> (src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv,
310+ ne10_fastdiv, ne11_fastdiv, ne12_fastdiv, ne13_fastdiv,
311+ /* s0, */ s1, s2, s3,
312+ /* s00,*/ s01, s02, s03,
313+ /* s10,*/ s11, s12, s13);
278314 }
279315 }
280316 }
0 commit comments