@@ -1650,11 +1650,12 @@ struct test_mul_mat : public test_case {
16501650    const  int64_t  m;
16511651    const  int64_t  n;
16521652    const  int64_t  k;
1653-     const  std::array<int64_t , 2 > bs; //  dims 3 and 4
1654-     const  std::array<int64_t , 2 > nr; //  repeat in dims 3 and 4
1653+     const  std::array<int64_t , 2 > bs;  //  dims 3 and 4
1654+     const  std::array<int64_t , 2 > nr;  //  repeat in dims 3 and 4
1655+     const  std::array<int64_t , 4 > per; //  permutation of dimensions
16551656
16561657    std::string vars () override  {
1657-         return  VARS_TO_STR7 (type_a, type_b, m, n, k, bs, nr);
1658+         return  VARS_TO_STR8 (type_a, type_b, m, n, k, bs, nr, per );
16581659    }
16591660
16601661    double  max_nmse_err () override  {
@@ -1669,8 +1670,9 @@ struct test_mul_mat : public test_case {
16691670    test_mul_mat (ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
16701671            int64_t  m = 32 , int64_t  n = 32 , int64_t  k = 32 ,
16711672            std::array<int64_t , 2 > bs = {10 , 10 },
1672-             std::array<int64_t , 2 > nr = {2 , 2 })
1673-         : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr) {}
1673+             std::array<int64_t , 2 > nr = {2 , 2 },
1674+             std::array<int64_t , 4 > per = {0 , 1 , 2 , 3 })
1675+         : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per) {}
16741676
16751677    ggml_tensor * build_graph (ggml_context * ctx) override  {
16761678        //  C^T = A * B^T: (k, m) * (k, n) => (m, n)
@@ -1681,6 +1683,28 @@ struct test_mul_mat : public test_case {
16811683        ggml_set_name (a, " a"  );
16821684        ggml_set_name (b, " b"  );
16831685
1686+         //  If the permutation is not {0, 1, 2, 3}, replace a and b with views that have the same data in a different order.
1687+         //  This test only works correctly if exactly 2 indices != 0 are swapped.
1688+         if  (per[0 ] != 0  || per[1 ] != 1  || per[2 ] != 2  || per[3 ] != 3 ) {
1689+             GGML_ASSERT (per[0 ] == 0 );
1690+             const  size_t  rsa = ggml_row_size (a->type , a->ne [0 ]);
1691+             const  size_t  rsb = ggml_row_size (b->type , b->ne [0 ]);
1692+             size_t  nba[GGML_MAX_DIMS] = {ggml_type_size (a->type ), rsa, rsa, rsa};
1693+             size_t  nbb[GGML_MAX_DIMS] = {ggml_type_size (b->type ), rsb, rsb, rsb};
1694+             for  (int64_t  i = 1 ; i < GGML_MAX_DIMS; ++i) {
1695+                 for  (int64_t  j = 1 ; j < per[i]; ++j) {
1696+                     nba[i] *= a->ne [per[j]];
1697+                     nbb[i] *= b->ne [per[j]];
1698+                 }
1699+             }
1700+             a = ggml_view_4d (ctx, a, a->ne [0 ], a->ne [1 ], a->ne [2 ], a->ne [3 ], nba[1 ], nba[2 ], nba[3 ], /* offset =*/   0 );
1701+             b = ggml_view_4d (ctx, b, b->ne [0 ], b->ne [1 ], b->ne [2 ], b->ne [3 ], nbb[1 ], nbb[2 ], nbb[3 ], /* offset =*/   0 );
1702+             GGML_ASSERT (ggml_nbytes (a) == ggml_nbytes (a->src [0 ]));
1703+             GGML_ASSERT (ggml_nbytes (b) == ggml_nbytes (b->src [0 ]));
1704+             ggml_set_name (a, " a_permuted"  );
1705+             ggml_set_name (b, " b_permuted"  );
1706+         }
1707+ 
16841708        ggml_tensor * out = ggml_mul_mat (ctx, a, b);
16851709        ggml_set_name (out, " out"  );
16861710
@@ -3442,13 +3466,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
34423466#if  1 
34433467    for  (ggml_type type_a : base_types) {
34443468        for  (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
3445-             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 1 , 256 , { 1 ,  1 }, {1 , 1 }));
3446-             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 ,  1 }, {1 , 1 }));
3447-             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 ,  1 }, {2 , 1 }));
3448-             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {1 , 1 }));
3449-             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {2 , 1 }));
3450-             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {1 , 2 }));
3451-             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {2 , 2 }));
3469+             //  test cases without permutation
3470+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , { 1 ,  1 }, {1 , 1 }));
3471+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {10 ,  1 }, {1 , 1 }));
3472+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {10 ,  1 }, {2 , 1 }));
3473+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {10 , 10 }, {1 , 1 }));
3474+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {10 , 10 }, {2 , 1 }));
3475+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {10 , 10 }, {1 , 2 }));
3476+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {10 , 10 }, {2 , 2 }));
34523477
34533478            test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , { 1 ,  1 }, {1 , 1 }));
34543479            test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 ,  1 }, {1 , 1 }));
@@ -3457,6 +3482,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
34573482            test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 10 }, {2 , 1 }));
34583483            test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 10 }, {1 , 2 }));
34593484            test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 10 }, {2 , 2 }));
3485+ 
3486+             //  test cases with permutation
3487+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {2 , 3 }, {1 , 1 }, {0 , 2 , 1 , 3 }));
3488+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {2 , 3 }, {1 , 1 }, {0 , 1 , 3 , 2 }));
3489+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  1 , 256 , {2 , 3 }, {1 , 1 }, {0 , 3 , 2 , 1 }));
3490+ 
3491+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  8 , 256 , {2 , 3 }, {1 , 1 }, {0 , 2 , 1 , 3 }));
3492+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  8 , 256 , {2 , 3 }, {1 , 1 }, {0 , 1 , 3 , 2 }));
3493+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 ,  8 , 256 , {2 , 3 }, {1 , 1 }, {0 , 3 , 2 , 1 }));
3494+ 
3495+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {2 , 3 }, {1 , 1 }, {0 , 2 , 1 , 3 }));
3496+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {2 , 3 }, {1 , 1 }, {0 , 1 , 3 , 2 }));
3497+             test_cases.emplace_back (new  test_mul_mat (type_a, type_b, 16 , 16 , 256 , {2 , 3 }, {1 , 1 }, {0 , 3 , 2 , 1 }));
34603498        }
34613499    }
34623500    for  (ggml_type type_a : other_types) {
0 commit comments