@@ -386,10 +386,13 @@ struct vk_flash_attn_push_constants {
386386    uint32_t  nev3;
387387    uint32_t  nem1;
388388
389+     uint32_t  nb01;
389390    uint32_t  nb02;
390391    uint32_t  nb03;
392+     uint32_t  nb11;
391393    uint32_t  nb12;
392394    uint32_t  nb13;
395+     uint32_t  nb21;
393396    uint32_t  nb22;
394397    uint32_t  nb23;
395398    uint32_t  nb31;
@@ -4809,7 +4812,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
48094812    }
48104813    assert (pipelines);
48114814
4812-     bool  aligned = (KV % pipelines[1 ]->align ) == 0 ;
4815+     const  uint32_t  q_stride = (uint32_t )(nbq1 / ggml_type_size (q->type ));
4816+     const  uint32_t  k_stride = (uint32_t )(nbk1 / ggml_type_size (k->type ));
4817+     const  uint32_t  v_stride = (uint32_t )(nbv1 / ggml_type_size (v->type ));
4818+ 
4819+     bool  aligned = (KV % pipelines[1 ]->align ) == 0  &&
4820+                    //  the "aligned" shader variant forcibly aligns strides (for performance), so only select it when the Q/K/V row strides are already multiples of 8
4821+                    (q_stride & 7 ) == 0  && (k_stride & 7 ) == 0  && (v_stride & 7 ) == 0 ;
4822+ 
48134823    vk_pipeline pipeline = pipelines[aligned];
48144824    assert (pipeline);
48154825
@@ -4845,15 +4855,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
48454855
48464856    if  (ctx->device ->uma ) {
48474857        ggml_vk_host_get (ctx->device , q->data , d_Q, q_buf_offset);
4848-         ggml_vk_host_get (ctx->device , k->data , d_K, q_buf_offset );
4849-         ggml_vk_host_get (ctx->device , v->data , d_V, q_buf_offset );
4850-         ggml_vk_host_get (ctx->device , dst->data , d_D, q_buf_offset );
4858+         ggml_vk_host_get (ctx->device , k->data , d_K, k_buf_offset );
4859+         ggml_vk_host_get (ctx->device , v->data , d_V, v_buf_offset );
4860+         ggml_vk_host_get (ctx->device , dst->data , d_D, d_buf_offset );
48514861        Q_uma = d_Q != nullptr ;
48524862        K_uma = d_K != nullptr ;
48534863        V_uma = d_V != nullptr ;
48544864        D_uma = d_D != nullptr ;
48554865        if  (mask) {
4856-             ggml_vk_host_get (ctx->device , mask->data , d_M, q_buf_offset );
4866+             ggml_vk_host_get (ctx->device , mask->data , d_M, m_buf_offset );
48574867            M_uma = d_M != nullptr ;
48584868        }
48594869    }
@@ -4891,7 +4901,18 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
48914901        }
48924902    }
48934903
4894-     const  vk_flash_attn_push_constants pc = { N, KV, (uint32_t )ne1, (uint32_t )ne2, (uint32_t )ne3, (uint32_t )neq2, (uint32_t )neq3, (uint32_t )nek2, (uint32_t )nek3, (uint32_t )nev2, (uint32_t )nev3, nem1, (uint32_t )nbq2, (uint32_t )nbq3, (uint32_t )nbk2, (uint32_t )nbk3, (uint32_t )nbv2, (uint32_t )nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr , n_head_log2, m0, m1 };
4904+     const  vk_flash_attn_push_constants pc = { N, KV,
4905+                                               (uint32_t )ne1, (uint32_t )ne2, (uint32_t )ne3,
4906+                                               (uint32_t )neq2, (uint32_t )neq3,
4907+                                               (uint32_t )nek2, (uint32_t )nek3,
4908+                                               (uint32_t )nev2, (uint32_t )nev3,
4909+                                               nem1,
4910+                                               q_stride, (uint32_t )nbq2, (uint32_t )nbq3,
4911+                                               k_stride, (uint32_t )nbk2, (uint32_t )nbk3,
4912+                                               v_stride, (uint32_t )nbv2, (uint32_t )nbv3,
4913+                                               nbm1,
4914+                                               scale, max_bias, logit_softcap,
4915+                                               mask != nullptr , n_head_log2, m0, m1 };
48954916    ggml_vk_dispatch_pipeline (ctx, subctx, pipeline,
48964917                                {
48974918                                    vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
@@ -8668,6 +8689,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
86688689    ggml_tensor * src0 = tensor->src [0 ];
86698690    ggml_tensor * src1 = tensor->src [1 ];
86708691    ggml_tensor * src2 = tensor->src [2 ];
8692+     ggml_tensor * src3 = tensor->src [3 ];
86718693
86728694    void  * tensor_data = tensor->data ;
86738695
@@ -8730,6 +8752,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
87308752                        if  (src2 != nullptr ) {
87318753                            std::cerr << " src2="   << src2 << "  src2->name="   << src2->name  << "  op="   << ggml_op_name (src2->op ) << "  type="   << ggml_type_name (src2->type ) << "  ne0="   << src2->ne [0 ] << "  nb0="   << src2->nb [0 ] << "  ne1="   << src2->ne [1 ] << "  nb1="   << src2->nb [1 ] << "  ne2="   << src2->ne [2 ] << "  nb2="   << src2->nb [2 ] << "  ne3="   << src2->ne [3 ] << "  nb3="   << src2->nb [3 ] << "  offset="   << src2->view_offs  << std::endl;
87328754                        }
8755+                         if  (src3 != nullptr ) {
8756+                             std::cerr << " src3="   << src3 << "  src3->name="   << src3->name  << "  op="   << ggml_op_name (src3->op ) << "  type="   << ggml_type_name (src3->type ) << "  ne0="   << src3->ne [0 ] << "  nb0="   << src3->nb [0 ] << "  ne1="   << src3->ne [1 ] << "  nb1="   << src3->nb [1 ] << "  ne2="   << src3->ne [2 ] << "  nb2="   << src3->nb [2 ] << "  ne3="   << src3->ne [3 ] << "  nb3="   << src3->nb [3 ] << "  offset="   << src3->view_offs  << std::endl;
8757+                         }
87338758                        std::cerr << " First error: result="   << first_error_result << "  correct="   << first_error_correct  << "  i3="   << first_error[3 ] << "  i2="   << first_error[2 ] << "  i1="   << first_error[1 ] << "  i0="   << first_error[0 ] << std::endl;
87348759                        std::cerr << std::endl << " Result:"   << std::endl;
87358760                        ggml_vk_print_tensor_area (tensor, tensor_data, i0, i1, i2, i3);
@@ -8774,6 +8799,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
87748799        if  (src2 != nullptr ) {
87758800            std::cerr << " src2="   << src2 << "  op="   << ggml_op_name (src2->op ) << "  type="   << ggml_type_name (src2->type ) << "  ne0="   << src2->ne [0 ] << "  nb0="   << src2->nb [0 ] << "  ne1="   << src2->ne [1 ] << "  nb1="   << src2->nb [1 ] << "  ne2="   << src2->ne [2 ] << "  nb2="   << src2->nb [2 ] << "  ne3="   << src2->ne [3 ] << "  nb3="   << src2->nb [3 ] << "  offset="   << src2->view_offs  << std::endl;
87768801        }
8802+         if  (src3 != nullptr ) {
8803+             std::cerr << " src3="   << src3 << "  op="   << ggml_op_name (src3->op ) << "  type="   << ggml_type_name (src3->type ) << "  ne0="   << src3->ne [0 ] << "  nb0="   << src3->nb [0 ] << "  ne1="   << src3->ne [1 ] << "  nb1="   << src3->nb [1 ] << "  ne2="   << src3->ne [2 ] << "  nb2="   << src3->nb [2 ] << "  ne3="   << src3->ne [3 ] << "  nb3="   << src3->nb [3 ] << "  offset="   << src3->view_offs  << std::endl;
8804+         }
87778805        std::cerr << " First error: result="   << first_error_result << "  correct="   << first_error_correct  << "  i3="   << first_error[3 ] << "  i2="   << first_error[2 ] << "  i1="   << first_error[1 ] << "  i0="   << first_error[0 ] << std::endl;
87788806        std::cerr << std::endl << " Result:"   << std::endl;
87798807        ggml_vk_print_tensor_area (tensor, tensor_data, 5 , 5 , 0 , 0 );
@@ -8796,6 +8824,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
87968824        if  (src2 != nullptr ) {
87978825            std::cerr << " src2="   << src2 << "  op="   << ggml_op_name (src2->op ) << "  type="   << ggml_type_name (src2->type ) << "  ne0="   << src2->ne [0 ] << "  nb0="   << src2->nb [0 ] << "  ne1="   << src2->ne [1 ] << "  nb1="   << src2->nb [1 ] << "  ne2="   << src2->ne [2 ] << "  nb2="   << src2->nb [2 ] << "  ne3="   << src2->ne [3 ] << "  nb3="   << src2->nb [3 ] << "  offset="   << src2->view_offs  << std::endl;
87988826        }
8827+         if  (src3 != nullptr ) {
8828+             std::cerr << " src3="   << src3 << "  op="   << ggml_op_name (src3->op ) << "  type="   << ggml_type_name (src3->type ) << "  ne0="   << src3->ne [0 ] << "  nb0="   << src3->nb [0 ] << "  ne1="   << src3->ne [1 ] << "  nb1="   << src3->nb [1 ] << "  ne2="   << src3->ne [2 ] << "  nb2="   << src3->nb [2 ] << "  ne3="   << src3->ne [3 ] << "  nb3="   << src3->nb [3 ] << "  offset="   << src3->view_offs  << std::endl;
8829+         }
87998830        std::cerr << " First error: result="   << first_error_result << "  correct="   << first_error_correct  << "  i3="   << first_error[3 ] << "  i2="   << first_error[2 ] << "  i1="   << first_error[1 ] << "  i0="   << first_error[0 ] << std::endl;
88008831        std::cerr << std::endl << " Result:"   << std::endl;
88018832        ggml_vk_print_tensor_area (tensor, tensor_data, first_error[0 ], first_error[1 ], first_error[2 ], first_error[3 ]);
0 commit comments