@@ -3666,7 +3666,7 @@ struct test_flash_attn_ext : public test_case {
36663666
36673667 ggml_tensor * m = nullptr ;
36683668 if (mask) {
3669- m = ggml_new_tensor_4d (ctx, GGML_TYPE_F16, kv, GGML_PAD (nb, GGML_KQ_MASK_PAD), nr23[1 ], 1 );
3669+ m = ggml_new_tensor_4d (ctx, GGML_TYPE_F16, kv, GGML_PAD (nb, GGML_KQ_MASK_PAD), nr23[0 ], nr23[ 1 ] );
36703670 ggml_set_name (m, " m" );
36713671 }
36723672
@@ -4780,7 +4780,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
47804780 test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, m_prec, {1 , 1 }, scale, max_bias));
47814781
47824782 if (ne0 <= 32 && ne1 <= 32 ) {
4783- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 1 }, mask, m_prec, {3 , 1 }, scale, max_bias));
4783+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 3 }, mask, m_prec, {3 , 1 }, scale, max_bias));
47844784 test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, m_prec, {2 , 3 }, scale, max_bias));
47854785 }
47864786 }
0 commit comments