@@ -1650,11 +1650,12 @@ struct test_mul_mat : public test_case {
16501650 const int64_t m;
16511651 const int64_t n;
16521652 const int64_t k;
1653- const std::array<int64_t , 2 > bs; // dims 3 and 4
1654- const std::array<int64_t , 2 > nr; // repeat in dims 3 and 4
1653+ const std::array<int64_t , 2 > bs; // dims 3 and 4
1654+ const std::array<int64_t , 2 > nr; // repeat in dims 3 and 4
1655+ const std::array<int64_t , 4 > per; // permutation of dimensions
16551656
16561657 std::string vars () override {
1657- return VARS_TO_STR7 (type_a, type_b, m, n, k, bs, nr);
1658+ return VARS_TO_STR8 (type_a, type_b, m, n, k, bs, nr, per );
16581659 }
16591660
16601661 double max_nmse_err () override {
@@ -1669,17 +1670,44 @@ struct test_mul_mat : public test_case {
16691670 test_mul_mat (ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
16701671 int64_t m = 32 , int64_t n = 32 , int64_t k = 32 ,
16711672 std::array<int64_t , 2 > bs = {10 , 10 },
1672- std::array<int64_t , 2 > nr = {2 , 2 })
1673- : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr) {}
1673+ std::array<int64_t , 2 > nr = {2 , 2 },
1674+ std::array<int64_t , 4 > per = {0 , 1 , 2 , 3 })
1675+ : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per) {}
16741676
16751677 ggml_tensor * build_graph (ggml_context * ctx) override {
16761678 // C^T = A * B^T: (k, m) * (k, n) => (m, n)
1677- ggml_tensor * a = ggml_new_tensor_4d (ctx, type_a, k, m, bs[0 ] , bs[1 ]);
1678- ggml_tensor * b = ggml_new_tensor_4d (ctx, type_b, k, n, bs[0 ]*nr[0 ], bs[1 ]*nr[1 ]);
1679- ggml_set_param (ctx, a);
1680- ggml_set_param (ctx, b);
1681- ggml_set_name (a, " a" );
1682- ggml_set_name (b, " b" );
1679+ ggml_tensor * a;
1680+ ggml_tensor * b;
1681+
1682+ const int npermuted = (per[0 ] != 0 ) + (per[1 ] != 1 ) + (per[2 ] != 2 ) + (per[3 ] != 3 );
1683+ if (npermuted > 0 ) {
1684+ GGML_ASSERT (npermuted == 2 );
1685+ GGML_ASSERT (!ggml_is_quantized (type_a) || per[0 ] == 0 );
1686+ GGML_ASSERT (!ggml_is_quantized (type_b) || per[0 ] == 0 );
1687+
1688+ // Create tensors with the permuted dimensions, then permute them back to the dimensions given by m,n,k.
1689+ const int64_t ne_a[4 ] = {k, m, bs[0 ], bs[1 ]};
1690+ const int64_t ne_b[4 ] = {k, n, bs[0 ]*nr[0 ], bs[1 ]*nr[1 ]};
1691+
1692+ a = ggml_new_tensor_4d (ctx, type_a, ne_a[per[0 ]], ne_a[per[1 ]], ne_a[per[2 ]], ne_a[per[3 ]]);
1693+ b = ggml_new_tensor_4d (ctx, type_b, ne_b[per[0 ]], ne_b[per[1 ]], ne_b[per[2 ]], ne_b[per[3 ]]);
1694+ ggml_set_param (ctx, a);
1695+ ggml_set_param (ctx, b);
1696+ ggml_set_name (a, " a" );
1697+ ggml_set_name (b, " b" );
1698+
1699+ a = ggml_permute (ctx, a, per[0 ], per[1 ], per[2 ], per[3 ]);
1700+ b = ggml_permute (ctx, b, per[0 ], per[1 ], per[2 ], per[3 ]);
1701+ ggml_set_name (a, " a_permuted" );
1702+ ggml_set_name (b, " b_permuted" );
1703+ } else {
1704+ a = ggml_new_tensor_4d (ctx, type_a, k, m, bs[0 ], bs[1 ]);
1705+ b = ggml_new_tensor_4d (ctx, type_b, k, n, bs[0 ]*nr[0 ], bs[1 ]*nr[1 ]);
1706+ ggml_set_param (ctx, a);
1707+ ggml_set_param (ctx, b);
1708+ ggml_set_name (a, " a" );
1709+ ggml_set_name (b, " b" );
1710+ }
16831711
16841712 ggml_tensor * out = ggml_mul_mat (ctx, a, b);
16851713 ggml_set_name (out, " out" );
@@ -3478,13 +3506,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
34783506#if 1
34793507 for (ggml_type type_a : base_types) {
34803508 for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
3481- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , { 1 , 1 }, {1 , 1 }));
3482- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 1 }, {1 , 1 }));
3483- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 1 }, {2 , 1 }));
3484- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {1 , 1 }));
3485- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {2 , 1 }));
3486- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {1 , 2 }));
3487- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {2 , 2 }));
3509+ // test cases without permutation
3510+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , { 1 , 1 }, {1 , 1 }));
3511+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 1 }, {1 , 1 }));
3512+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 1 }, {2 , 1 }));
3513+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {1 , 1 }));
3514+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {2 , 1 }));
3515+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {1 , 2 }));
3516+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {2 , 2 }));
34883517
34893518 test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , { 1 , 1 }, {1 , 1 }));
34903519 test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 1 }, {1 , 1 }));
@@ -3493,6 +3522,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
34933522 test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 10 }, {2 , 1 }));
34943523 test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 10 }, {1 , 2 }));
34953524 test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 10 }, {2 , 2 }));
3525+
3526+ // test cases with permutation
3527+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {2 , 3 }, {1 , 1 }, {0 , 2 , 1 , 3 }));
3528+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {2 , 3 }, {1 , 1 }, {0 , 1 , 3 , 2 }));
3529+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {2 , 3 }, {1 , 1 }, {0 , 3 , 2 , 1 }));
3530+
3531+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 8 , 256 , {2 , 3 }, {1 , 1 }, {0 , 2 , 1 , 3 }));
3532+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 8 , 256 , {2 , 3 }, {1 , 1 }, {0 , 1 , 3 , 2 }));
3533+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 8 , 256 , {2 , 3 }, {1 , 1 }, {0 , 3 , 2 , 1 }));
3534+
3535+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {2 , 3 }, {1 , 1 }, {0 , 2 , 1 , 3 }));
3536+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {2 , 3 }, {1 , 1 }, {0 , 1 , 3 , 2 }));
3537+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {2 , 3 }, {1 , 1 }, {0 , 3 , 2 , 1 }));
34963538 }
34973539 }
34983540 for (ggml_type type_a : other_types) {
0 commit comments