@@ -95,8 +95,8 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
9595 const uint3 ne1,
9696 const uint3 ne2,
9797 const uint32_t ne3,
98- const uint3 prod012 ,
99- const uint3 prod01 ,
98+ const uint3 prod_012 ,
99+ const uint3 prod_01 ,
100100 const uint3 ne10,
101101 const uint3 ne11,
102102 const uint3 ne12,
@@ -113,10 +113,10 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
113113 src1_ptrs... src1s) {
114114 const int i = blockDim .x *blockIdx .x + threadIdx .x ;
115115
116- const uint32_t i3 = fastdiv (i, prod012 );
117- const uint32_t i2 = fastdiv (i - i3 * prod012 .z , prod01 );
118- const uint32_t i1 = fastdiv (i - i3 * prod012 .z - i2 * prod01 .z , ne0);
119- const uint32_t i0 = i - i3 * prod012 .z - i2 * prod01 .z - i1 * ne0.z ;
116+ const uint32_t i3 = fastdiv (i, prod_012 );
117+ const uint32_t i2 = fastdiv (i - i3 * prod_012 .z , prod_01 );
118+ const uint32_t i1 = fastdiv (i - i3 * prod_012 .z - i2 * prod_01 .z , ne0);
119+ const uint32_t i0 = i - i3 * prod_012 .z - i2 * prod_01 .z - i1 * ne0.z ;
120120
121121 if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
122122 return ;
@@ -274,23 +274,23 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
274274
275275 if (block_nums.z > 65535 ) {
276276 int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1 ) / block_size;
277- const uint3 prod012 = init_fastdiv_values ((uint32_t ) (ne0 * ne1 * ne2));
278- const uint3 prod01 = init_fastdiv_values ((uint32_t ) (ne0 * ne1));
277+ const uint3 prod_012 = init_fastdiv_values ((uint32_t ) (ne0 * ne1 * ne2));
278+ const uint3 prod_01 = init_fastdiv_values ((uint32_t ) (ne0 * ne1));
279279 const uint3 ne0_fastdiv = init_fastdiv_values ((uint32_t ) ne0);
280280 const uint3 ne1_fastdiv = init_fastdiv_values ((uint32_t ) ne1);
281281 const uint3 ne2_fastdiv = init_fastdiv_values ((uint32_t ) ne2);
282282
283283 if constexpr (sizeof ...(I) > 0 ) {
284284 k_bin_bcast_unravel<bin_op, src0_t , src1_t , dst_t ><<<block_num, block_size, 0 , stream>>> (
285- src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod012, prod01 , ne10, ne11,
285+ src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01 , ne10, ne11,
286286 ne12, ne13,
287287 /* s0, */ s1, s2, s3,
288288 /* s00,*/ s01, s02, s03,
289289 /* s10,*/ s11, s12, s13, (const src1_t *) dst->src [I + 1 ]->data ...);
290290 } else {
291291 k_bin_bcast_unravel<bin_op, src0_t , src1_t , dst_t >
292292 <<<block_num, block_size, 0 , stream>>> (src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
293- ne2_fastdiv, ne3, prod012, prod01 , ne10, ne11, ne12, ne13,
293+ ne2_fastdiv, ne3, prod_012, prod_01 , ne10, ne11, ne12, ne13,
294294 /* s0, */ s1, s2, s3,
295295 /* s00,*/ s01, s02, s03,
296296 /* s10,*/ s11, s12, s13);
0 commit comments