@@ -352,7 +352,45 @@ struct vk_op_unary_push_constants {
352352    uint32_t  ne10; uint32_t  ne11; uint32_t  ne12; uint32_t  ne13; uint32_t  nb10; uint32_t  nb11; uint32_t  nb12; uint32_t  nb13;
353353    uint32_t  d_offset;
354354    float  param1; float  param2;
355+     uint32_t  ne0_012mp; uint32_t  ne0_012L;
356+     uint32_t  ne0_01mp;  uint32_t  ne0_01L;
357+     uint32_t  ne0_0mp;   uint32_t  ne0_0L;
358+     uint32_t  ne1_012mp; uint32_t  ne1_012L;
359+     uint32_t  ne1_01mp;  uint32_t  ne1_01L;
360+     uint32_t  ne1_0mp;   uint32_t  ne1_0L;
355361};
// Compile-time guard: the unary push-constant block must not exceed 128 bytes.
// NOTE(review): 128 is presumably the minimum maxPushConstantsSize that the
// Vulkan spec guarantees on every device - confirm before raising the limit.
362+ static_assert (sizeof (vk_op_unary_push_constants) <= 128 , " sizeof(vk_op_unary_push_constants) must be <= 128"  );
363+ 
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)
// and a shift:
//
// n/d = (mulhi(n, mp) + n) >> L;
//
// Parameters:
//   d  - the divisor to precompute magic values for.
//   mp - receives the 32-bit magic multiplier.
//   L  - receives the shift amount, ceil(log2(d)), in the range [0, 32].
//
// A divisor of zero has no meaningful quotient. Instead of performing an
// integer division by zero on the host (undefined behaviour), mp and L are
// both set to 0, which makes the formula above evaluate to n.
void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L)
{
    if (d == 0) {
        mp = 0;
        L  = 0;
        return;
    }

    // compute L = ceil(log2(d)); the L < 32 test keeps the 32-bit shift defined
    // for divisors above 2^31.
    L = 0;
    while (L < 32 && (uint32_t{1} << L) < d) {
        L++;
    }

    // 2^L - d < d here, so the 64-bit product below cannot overflow.
    const uint64_t two_pow_32 = uint64_t{1} << 32;
    const uint64_t remainder  = (uint64_t{1} << L) - d;
    mp = static_cast<uint32_t>(two_pow_32 * remainder / d + 1);
}
380+ 
// Generic fallback for push-constant types that carry no fastdiv fields:
// there is nothing to initialize, so the argument is left untouched.
// The argument type must be non-const; a const argument is rejected at
// compile time.
template <typename T>
void init_pushconst_fastdiv(T & /*p*/)
{
    static_assert(std::is_const<T>::value == false, " unexpected type");
}
384+ 
385+ template  <> void  init_pushconst_fastdiv (vk_op_unary_push_constants &p) {
386+     //  Compute magic values to divide by these six numbers.
387+     init_fastdiv_values (p.ne02 *p.ne01 *p.ne00 ,  p.ne0_012mp ,    p.ne0_012L );
388+     init_fastdiv_values (p.ne01 *p.ne00 ,         p.ne0_01mp ,     p.ne0_01L );
389+     init_fastdiv_values (p.ne00 ,                p.ne0_0mp ,      p.ne0_0L );
390+     init_fastdiv_values (p.ne12 *p.ne11 *p.ne10 ,  p.ne1_012mp ,    p.ne1_012L );
391+     init_fastdiv_values (p.ne11 *p.ne10 ,         p.ne1_01mp ,     p.ne1_01L );
392+     init_fastdiv_values (p.ne10 ,                p.ne1_0mp ,      p.ne1_0L );
393+ }
356394
357395struct  vk_op_binary_push_constants  {
358396    uint32_t  ne;
@@ -2885,13 +2923,14 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
28852923        elements = { ne, 1 , 1  };
28862924    }
28872925
2888-     const   vk_op_unary_push_constants pc = {
2926+     vk_op_unary_push_constants pc = {
28892927        (uint32_t )ne,
28902928        (uint32_t )tensor->ne [0 ], (uint32_t )tensor->ne [1 ], (uint32_t )tensor->ne [2 ], (uint32_t )tensor->ne [3 ], (uint32_t )tensor->nb [0 ] / tensor_type_size, (uint32_t )tensor->nb [1 ] / tensor_type_size, (uint32_t )tensor->nb [2 ] / tensor_type_size, (uint32_t )tensor->nb [3 ] / tensor_type_size,
28912929        (uint32_t )tensor->ne [0 ], (uint32_t )tensor->ne [1 ], (uint32_t )tensor->ne [2 ], (uint32_t )tensor->ne [3 ],                       1                    , (uint32_t )tensor->ne [0 ]                   , (uint32_t )(tensor->ne [0 ] * tensor->ne [1 ]) , (uint32_t )(tensor->ne [0 ] * tensor->ne [1 ] * tensor->ne [2 ]),
28922930        0 ,
28932931        0 .0f , 0 .0f ,
28942932    };
2933+     init_pushconst_fastdiv (pc);
28952934    ggml_vk_sync_buffers (subctx);
28962935    ggml_vk_dispatch_pipeline (ctx, subctx, pipeline, { in, out }, sizeof (vk_op_unary_push_constants), &pc, elements);
28972936}
@@ -4096,7 +4135,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
40964135}
40974136
40984137template <typename  PC>
4099- static  void  ggml_vk_op_f32 (ggml_backend_vk_context * ctx, vk_context& subctx, const  ggml_tensor * src0, const  ggml_tensor * src1, const  ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const   PC&& pc, bool  dryrun = false ) {
4138+ static  void  ggml_vk_op_f32 (ggml_backend_vk_context * ctx, vk_context& subctx, const  ggml_tensor * src0, const  ggml_tensor * src1, const  ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool  dryrun = false ) {
41004139    VK_LOG_DEBUG (" ggml_vk_op_f32(("   << src0 << " , name="   << src0->name  << " , type="   << src0->type  << " , ne0="   << src0->ne [0 ] << " , ne1="   << src0->ne [1 ] << " , ne2="   << src0->ne [2 ] << " , ne3="   << src0->ne [3 ] << " , nb0="   << src0->nb [0 ] << " , nb1="   << src0->nb [1 ] << " , nb2="   << src0->nb [2 ] << " , nb3="   << src0->nb [3 ];
41014140    if  (src1 != nullptr ) {
41024141        std::cerr << " ), ("   << src1 << " , name="   << src1->name  << " , type="   << src1->type  << " , ne0="   << src1->ne [0 ] << " , ne1="   << src1->ne [1 ] << " , ne2="   << src1->ne [2 ] << " , ne3="   << src1->ne [3 ] << " , nb0="   << src1->nb [0 ] << " , nb1="   << src1->nb [1 ] << " , nb2="   << src1->nb [2 ] << " , nb3="   << src1->nb [3 ];
@@ -4136,6 +4175,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
41364175    const  uint64_t  ned3 = dst->ne [3 ];
41374176    const  uint64_t  ned = ned0 * ned1;
41384177
4178+     init_pushconst_fastdiv (pc);
4179+ 
41394180    vk_pipeline pipeline = ggml_vk_op_get_pipeline (ctx, src0, src1, src2, dst, op);
41404181
41414182    if  (pipeline == nullptr ) {
0 commit comments