@@ -34,11 +34,11 @@ static __global__ void k_bin_bcast(const src0_t * src0,
3434 const int ne0,
3535 const int ne1,
3636 const int ne2,
37- const uint3 ne3_fastdiv ,
38- const uint3 ne10_fastdiv ,
39- const uint3 ne11_fastdiv ,
40- const uint3 ne12_fastdiv ,
41- const uint3 ne13_fastdiv ,
37+ const uint3 ne3 ,
38+ const uint3 ne10 ,
39+ const uint3 ne11 ,
40+ const uint3 ne12 ,
41+ const uint3 ne13 ,
4242 /* int s0, */ const int s1,
4343 const int s2,
4444 const int s3,
@@ -51,16 +51,16 @@ static __global__ void k_bin_bcast(const src0_t * src0,
5151 src1_ptrs... src1s) {
5252 const uint32_t i0s = blockDim .x * blockIdx .x + threadIdx .x ;
5353 const uint32_t i1 = (blockDim .y * blockIdx .y + threadIdx .y );
54- const uint32_t i2 = fastdiv ((blockDim .z * blockIdx .z + threadIdx .z ), ne3_fastdiv );
55- const uint32_t i3 = (blockDim .z * blockIdx .z + threadIdx .z ) - (i2 * ne3_fastdiv .z );
54+ const uint32_t i2 = fastdiv ((blockDim .z * blockIdx .z + threadIdx .z ), ne3 );
55+ const uint32_t i3 = (blockDim .z * blockIdx .z + threadIdx .z ) - (i2 * ne3 .z );
5656
57- if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3_fastdiv .z ) {
57+ if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3 .z ) {
5858 return ;
5959 }
6060
61- const uint32_t i11 = fastmodulo (i1, ne11_fastdiv );
62- const uint32_t i12 = fastmodulo (i2, ne12_fastdiv );
63- const uint32_t i13 = fastmodulo (i3, ne13_fastdiv );
61+ const uint32_t i11 = fastmodulo (i1, ne11 );
62+ const uint32_t i12 = fastmodulo (i2, ne12 );
63+ const uint32_t i13 = fastmodulo (i3, ne13 );
6464
6565 const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
6666 const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@@ -70,7 +70,7 @@ static __global__ void k_bin_bcast(const src0_t * src0,
7070 dst_t * dst_row = dst + i_dst;
7171
7272 for (int i0 = i0s; i0 < ne0; i0 += blockDim .x * gridDim .x ) {
73- const uint32_t i10 = fastmodulo (i0, ne10_fastdiv );
73+ const uint32_t i10 = fastmodulo (i0, ne10 );
7474
7575 float result = src0_row ? (float ) src0_row[i0] : 0 .0f ;
7676 if constexpr (sizeof ...(src1_ptrs) > 0 ) {
@@ -91,16 +91,16 @@ template <float (*bin_op)(const float, const float),
9191static __global__ void k_bin_bcast_unravel (const src0_t * src0,
9292 const src1_t * src1,
9393 dst_t * dst,
94- const uint3 ne0_fastdiv ,
95- const uint3 ne1_fastdiv ,
96- const uint3 ne2_fastdiv ,
94+ const uint3 ne0 ,
95+ const uint3 ne1 ,
96+ const uint3 ne2 ,
9797 const uint32_t ne3,
98- const uint3 ne012_fastdiv ,
99- const uint3 ne01_fastdiv ,
100- const uint3 ne10_fastdiv ,
101- const uint3 ne11_fastdiv ,
102- const uint3 ne12_fastdiv ,
103- const uint3 ne13_fastdiv ,
98+ const uint3 prod012 ,
99+ const uint3 prod01 ,
100+ const uint3 ne10 ,
101+ const uint3 ne11 ,
102+ const uint3 ne12 ,
103+ const uint3 ne13 ,
104104 /* int s0, */ const int s1,
105105 const int s2,
106106 const int s3,
@@ -113,18 +113,18 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
113113 src1_ptrs... src1s) {
114114 const int i = blockDim .x *blockIdx .x + threadIdx .x ;
115115
116- const uint32_t i3 = fastdiv (i, ne012_fastdiv );
117- const uint32_t i2 = fastmodulo ( fastdiv (i, ne01_fastdiv), ne2_fastdiv );
118- const uint32_t i1 = fastmodulo ( fastdiv (i, ne0_fastdiv), ne1_fastdiv );
119- const uint32_t i0 = fastmodulo (i, ne0_fastdiv) ;
116+ const uint32_t i3 = fastdiv (i, prod012 );
117+ const uint32_t i2 = fastdiv (i - i3 * prod012. z , prod01 );
118+ const uint32_t i1 = fastdiv (i - i3 * prod012. z - i2 * prod01. z , ne0 );
119+ const uint32_t i0 = i - i3 * prod012. z - i2 * prod01. z - i1 * ne0. z ;
120120
121- if (i0 >= ne0_fastdiv .z || i1 >= ne1_fastdiv .z || i2 >= ne2_fastdiv .z || i3 >= ne3) {
121+ if (i0 >= ne0 .z || i1 >= ne1 .z || i2 >= ne2 .z || i3 >= ne3) {
122122 return ;
123123 }
124124
125- const int i11 = fastmodulo (i1, ne11_fastdiv );
126- const int i12 = fastmodulo (i2, ne12_fastdiv );
127- const int i13 = fastmodulo (i3, ne13_fastdiv );
125+ const int i11 = fastmodulo (i1, ne11 );
126+ const int i12 = fastmodulo (i2, ne12 );
127+ const int i13 = fastmodulo (i3, ne13 );
128128
129129 const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
130130 const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@@ -133,7 +133,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
133133 const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr ;
134134 dst_t * dst_row = dst + i_dst;
135135
136- const int i10 = fastmodulo (i0, ne10_fastdiv );
136+ const int i10 = fastmodulo (i0, ne10 );
137137
138138 float result = src0_row ? (float ) src0_row[i0] : 0 .0f ;
139139 if constexpr (sizeof ...(src1_ptrs) > 0 ) {
@@ -267,50 +267,48 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
267267 dim3 block_nums ((hne0 + block_dims.x - 1 ) / block_dims.x , (ne1 + block_dims.y - 1 ) / block_dims.y ,
268268 (ne2 * ne3 + block_dims.z - 1 ) / block_dims.z );
269269
270-
271- const uint3 ne10_fastdiv = init_fastdiv_values ((uint32_t ) cne1[0 ]);
272- const uint3 ne11_fastdiv = init_fastdiv_values ((uint32_t ) cne1[1 ]);
273- const uint3 ne12_fastdiv = init_fastdiv_values ((uint32_t ) cne1[2 ]);
274- const uint3 ne13_fastdiv = init_fastdiv_values ((uint32_t ) cne1[3 ]);
270+ const uint3 ne10 = init_fastdiv_values ((uint32_t ) cne1[0 ]);
271+ const uint3 ne11 = init_fastdiv_values ((uint32_t ) cne1[1 ]);
272+ const uint3 ne12 = init_fastdiv_values ((uint32_t ) cne1[2 ]);
273+ const uint3 ne13 = init_fastdiv_values ((uint32_t ) cne1[3 ]);
275274
276275 if (block_nums.z > 65535 ) {
277- int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1 ) / block_size;
278- const uint3 ne0_fastdiv = init_fastdiv_values ((uint32_t ) ne0);
279- const uint3 ne1_fastdiv = init_fastdiv_values ((uint32_t ) ne1);
280- const uint3 ne2_fastdiv = init_fastdiv_values ((uint32_t ) ne2);
281- const uint3 ne012_fastdiv = init_fastdiv_values ((uint32_t ) (ne0 * ne1 * ne2));
282- const uint3 ne01_fastdiv = init_fastdiv_values ((uint32_t ) (ne0 * ne1));
276+ int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1 ) / block_size;
277+ const uint3 prod012 = init_fastdiv_values ((uint32_t ) (ne0 * ne1 * ne2));
278+ const uint3 prod01 = init_fastdiv_values ((uint32_t ) (ne0 * ne1));
279+ const uint3 ne0_fastdiv = init_fastdiv_values ((uint32_t ) ne0);
280+ const uint3 ne1_fastdiv = init_fastdiv_values ((uint32_t ) ne1);
281+ const uint3 ne2_fastdiv = init_fastdiv_values ((uint32_t ) ne2);
282+
283283 if constexpr (sizeof ...(I) > 0 ) {
284284 k_bin_bcast_unravel<bin_op, src0_t , src1_t , dst_t ><<<block_num, block_size, 0 , stream>>> (
285- src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, ne012_fastdiv, ne01_fastdiv ,
286- ne10_fastdiv, ne11_fastdiv, ne12_fastdiv, ne13_fastdiv ,
285+ src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod012, prod01, ne10, ne11 ,
286+ ne12, ne13 ,
287287 /* s0, */ s1, s2, s3,
288288 /* s00,*/ s01, s02, s03,
289289 /* s10,*/ s11, s12, s13, (const src1_t *) dst->src [I + 1 ]->data ...);
290290 } else {
291- k_bin_bcast_unravel<bin_op, src0_t , src1_t , dst_t ><<<block_num, block_size, 0 , stream>>> (
292- src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, ne012_fastdiv, ne01_fastdiv ,
293- ne10_fastdiv, ne11_fastdiv, ne12_fastdiv, ne13_fastdiv ,
294- /* s0, */ s1, s2, s3,
295- /* s00,*/ s01, s02, s03,
296- /* s10,*/ s11, s12, s13);
291+ k_bin_bcast_unravel<bin_op, src0_t , src1_t , dst_t >
292+ <<<block_num, block_size, 0 , stream>>> (src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv ,
293+ ne2_fastdiv, ne3, prod012, prod01, ne10, ne11, ne12, ne13 ,
294+ /* s0, */ s1, s2, s3,
295+ /* s00,*/ s01, s02, s03,
296+ /* s10,*/ s11, s12, s13);
297297 }
298298 } else {
299299 const uint3 ne3_fastdiv = init_fastdiv_values ((uint32_t ) ne3);
300300 if constexpr (sizeof ...(I) > 0 ) {
301301 k_bin_bcast<bin_op, src0_t , src1_t , dst_t ><<<block_nums, block_dims, 0 , stream>>> (
302- src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10_fastdiv, ne11_fastdiv, ne12_fastdiv,
303- ne13_fastdiv,
302+ src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
304303 /* s0, */ s1, s2, s3,
305304 /* s00,*/ s01, s02, s03,
306305 /* s10,*/ s11, s12, s13, (const src1_t *) dst->src [I + 1 ]->data ...);
307306 } else {
308- k_bin_bcast<bin_op, src0_t , src1_t , dst_t >
309- <<<block_nums, block_dims, 0 , stream>>> (src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv,
310- ne10_fastdiv, ne11_fastdiv, ne12_fastdiv, ne13_fastdiv,
311- /* s0, */ s1, s2, s3,
312- /* s00,*/ s01, s02, s03,
313- /* s10,*/ s11, s12, s13);
307+ k_bin_bcast<bin_op, src0_t , src1_t , dst_t ><<<block_nums, block_dims, 0 , stream>>> (
308+ src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
309+ /* s0, */ s1, s2, s3,
310+ /* s00,*/ s01, s02, s03,
311+ /* s10,*/ s11, s12, s13);
314312 }
315313 }
316314 }
0 commit comments