@@ -353,7 +353,45 @@ struct vk_op_unary_push_constants {
353353    uint32_t  ne10; uint32_t  ne11; uint32_t  ne12; uint32_t  ne13; uint32_t  nb10; uint32_t  nb11; uint32_t  nb12; uint32_t  nb13;
354354    uint32_t  d_offset;
355355    float  param1; float  param2;
356+     uint32_t  ne0_012mp; uint32_t  ne0_012L;
357+     uint32_t  ne0_01mp;  uint32_t  ne0_01L;
358+     uint32_t  ne0_0mp;   uint32_t  ne0_0L;
359+     uint32_t  ne1_012mp; uint32_t  ne1_012L;
360+     uint32_t  ne1_01mp;  uint32_t  ne1_01L;
361+     uint32_t  ne1_0mp;   uint32_t  ne1_0L;
356362};
363+ static_assert (sizeof (vk_op_unary_push_constants) <= 128 , " sizeof(vk_op_unary_push_constants) must be <= 128"  );
364+ 
365+ //  See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
366+ //  Precompute mp (m' in the paper) and L such that division
367+ //  can be computed using a multiply (high 32b of 64b result)
368+ //  and a shift:
369+ // 
370+ //  n/d = (mulhi(n, mp) + n) >> L;
371+ void  init_fastdiv_values (uint32_t  d, uint32_t  &mp, uint32_t  &L)
372+ {
373+     //  compute L = ceil(log2(d));
374+     L = 0 ;
375+     while  (L < 32  && (uint32_t {1 } << L) < d) {
376+         L++;
377+     }
378+ 
379+     mp = (uint32_t )((uint64_t {1 } << 32 ) * ((uint64_t {1 } << L) - d) / d + 1 );
380+ }
381+ 
382+ template  <typename  T> void  init_pushconst_fastdiv (T &p) {
383+     static_assert (!std::is_const<T>::value, " unexpected type"  );
384+ }
385+ 
386+ template  <> void  init_pushconst_fastdiv (vk_op_unary_push_constants &p) {
387+     //  Compute magic values to divide by these six numbers.
388+     init_fastdiv_values (p.ne02 *p.ne01 *p.ne00 ,  p.ne0_012mp ,    p.ne0_012L );
389+     init_fastdiv_values (p.ne01 *p.ne00 ,         p.ne0_01mp ,     p.ne0_01L );
390+     init_fastdiv_values (p.ne00 ,                p.ne0_0mp ,      p.ne0_0L );
391+     init_fastdiv_values (p.ne12 *p.ne11 *p.ne10 ,  p.ne1_012mp ,    p.ne1_012L );
392+     init_fastdiv_values (p.ne11 *p.ne10 ,         p.ne1_01mp ,     p.ne1_01L );
393+     init_fastdiv_values (p.ne10 ,                p.ne1_0mp ,      p.ne1_0L );
394+ }
357395
358396struct  vk_op_binary_push_constants  {
359397    uint32_t  ne;
@@ -2914,13 +2952,14 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
29142952        elements = { ne, 1 , 1  };
29152953    }
29162954
2917-     const   vk_op_unary_push_constants pc = {
2955+     vk_op_unary_push_constants pc = {
29182956        (uint32_t )ne,
29192957        (uint32_t )tensor->ne [0 ], (uint32_t )tensor->ne [1 ], (uint32_t )tensor->ne [2 ], (uint32_t )tensor->ne [3 ], (uint32_t )tensor->nb [0 ] / tensor_type_size, (uint32_t )tensor->nb [1 ] / tensor_type_size, (uint32_t )tensor->nb [2 ] / tensor_type_size, (uint32_t )tensor->nb [3 ] / tensor_type_size,
29202958        (uint32_t )tensor->ne [0 ], (uint32_t )tensor->ne [1 ], (uint32_t )tensor->ne [2 ], (uint32_t )tensor->ne [3 ],                       1                    , (uint32_t )tensor->ne [0 ]                   , (uint32_t )(tensor->ne [0 ] * tensor->ne [1 ]) , (uint32_t )(tensor->ne [0 ] * tensor->ne [1 ] * tensor->ne [2 ]),
29212959        0 ,
29222960        0 .0f , 0 .0f ,
29232961    };
2962+     init_pushconst_fastdiv (pc);
29242963    ggml_vk_sync_buffers (subctx);
29252964    ggml_vk_dispatch_pipeline (ctx, subctx, pipeline, { in, out }, sizeof (vk_op_unary_push_constants), &pc, elements);
29262965}
@@ -4125,7 +4164,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
41254164}
41264165
41274166template <typename  PC>
4128- static  void  ggml_vk_op_f32 (ggml_backend_vk_context * ctx, vk_context& subctx, const  ggml_tensor * src0, const  ggml_tensor * src1, const  ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const   PC&& pc, bool  dryrun = false ) {
4167+ static  void  ggml_vk_op_f32 (ggml_backend_vk_context * ctx, vk_context& subctx, const  ggml_tensor * src0, const  ggml_tensor * src1, const  ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool  dryrun = false ) {
41294168    VK_LOG_DEBUG (" ggml_vk_op_f32(("   << src0 << " , name="   << src0->name  << " , type="   << src0->type  << " , ne0="   << src0->ne [0 ] << " , ne1="   << src0->ne [1 ] << " , ne2="   << src0->ne [2 ] << " , ne3="   << src0->ne [3 ] << " , nb0="   << src0->nb [0 ] << " , nb1="   << src0->nb [1 ] << " , nb2="   << src0->nb [2 ] << " , nb3="   << src0->nb [3 ];
41304169    if  (src1 != nullptr ) {
41314170        std::cerr << " ), ("   << src1 << " , name="   << src1->name  << " , type="   << src1->type  << " , ne0="   << src1->ne [0 ] << " , ne1="   << src1->ne [1 ] << " , ne2="   << src1->ne [2 ] << " , ne3="   << src1->ne [3 ] << " , nb0="   << src1->nb [0 ] << " , nb1="   << src1->nb [1 ] << " , nb2="   << src1->nb [2 ] << " , nb3="   << src1->nb [3 ];
@@ -4165,6 +4204,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
41654204    const  uint64_t  ned3 = dst->ne [3 ];
41664205    const  uint64_t  ned = ned0 * ned1;
41674206
4207+     init_pushconst_fastdiv (pc);
4208+ 
41684209    vk_pipeline pipeline = ggml_vk_op_get_pipeline (ctx, src0, src1, src2, dst, op);
41694210
41704211    if  (pipeline == nullptr ) {
0 commit comments