@@ -1676,33 +1676,37 @@ struct test_mul_mat : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         // C^T = A * B^T: (k, m) * (k, n) => (m, n)
-        ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
-        ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
-        ggml_set_param(ctx, a);
-        ggml_set_param(ctx, b);
-        ggml_set_name(a, "a");
-        ggml_set_name(b, "b");
+        ggml_tensor * a;
+        ggml_tensor * b;
 
-        // If the permutation is not {0, 1, 2, 3}, replace a and b with views that have the same data in a different order.
-        // This test only works correctly if exactly 2 indices != 0 are swapped.
-        if (per[0] != 0 || per[1] != 1 || per[2] != 2 || per[3] != 3) {
-            GGML_ASSERT(per[0] == 0);
-            const size_t rsa = ggml_row_size(a->type, a->ne[0]);
-            const size_t rsb = ggml_row_size(b->type, b->ne[0]);
-            size_t nba[GGML_MAX_DIMS] = {ggml_type_size(a->type), rsa, rsa, rsa};
-            size_t nbb[GGML_MAX_DIMS] = {ggml_type_size(b->type), rsb, rsb, rsb};
-            for (int64_t i = 1; i < GGML_MAX_DIMS; ++i) {
-                for (int64_t j = 1; j < per[i]; ++j) {
-                    nba[i] *= a->ne[per[j]];
-                    nbb[i] *= b->ne[per[j]];
-                }
-            }
-            a = ggml_view_4d(ctx, a, a->ne[0], a->ne[1], a->ne[2], a->ne[3], nba[1], nba[2], nba[3], /*offset =*/ 0);
-            b = ggml_view_4d(ctx, b, b->ne[0], b->ne[1], b->ne[2], b->ne[3], nbb[1], nbb[2], nbb[3], /*offset =*/ 0);
-            GGML_ASSERT(ggml_nbytes(a) == ggml_nbytes(a->src[0]));
-            GGML_ASSERT(ggml_nbytes(b) == ggml_nbytes(b->src[0]));
+        const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3);
+        if (npermuted > 0) {
+            GGML_ASSERT(npermuted == 2);
+            GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0);
+            GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0);
+
+            // Create tensors with the permuted dimensions, then permute them back to the dimensions given by m,n,k.
+            const int64_t ne_a[4] = {k, m, bs[0], bs[1]};
+            const int64_t ne_b[4] = {k, n, bs[0]*nr[0], bs[1]*nr[1]};
+
+            a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]);
+            b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
+            ggml_set_param(ctx, a);
+            ggml_set_param(ctx, b);
+            ggml_set_name(a, "a");
+            ggml_set_name(b, "b");
+
+            a = ggml_permute(ctx, a, per[0], per[1], per[2], per[3]);
+            b = ggml_permute(ctx, b, per[0], per[1], per[2], per[3]);
             ggml_set_name(a, "a_permuted");
             ggml_set_name(b, "b_permuted");
+        } else {
+            a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
+            b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
+            ggml_set_param(ctx, a);
+            ggml_set_param(ctx, b);
+            ggml_set_name(a, "a");
+            ggml_set_name(b, "b");
         }
 
         ggml_tensor * out = ggml_mul_mat(ctx, a, b);
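
For context (not part of the patch): the new code allocates `a` and `b` with their extents reordered by `per`, then applies `ggml_permute` with the same indices, so `ggml_mul_mat` sees the same logical m,n,k shape as the unpermuted case while the underlying memory layout stays permuted. A minimal standalone sketch of that index arithmetic, assuming ggml_permute's convention that source axis i becomes result axis per[i] (the extents below are hypothetical, not taken from the test):

// Standalone sketch of the shape round-trip; no ggml dependency.
#include <array>
#include <cassert>
#include <cstdint>

int main() {
    const std::array<int64_t, 4> ne_orig = {16, 32, 2, 3}; // hypothetical k, m, bs[0], bs[1]
    const std::array<int, 4>     per     = {0, 2, 1, 3};   // exactly two indices swapped

    // Allocation order used by the test: ne[per[0]], ne[per[1]], ne[per[2]], ne[per[3]].
    std::array<int64_t, 4> ne_alloc{};
    for (int i = 0; i < 4; ++i) {
        ne_alloc[i] = ne_orig[per[i]];
    }

    // Assumption: ggml_permute(ctx, t, per[0], per[1], per[2], per[3]) places source axis i
    // at result axis per[i], i.e. result->ne[per[i]] == t->ne[i].
    std::array<int64_t, 4> ne_view{};
    for (int i = 0; i < 4; ++i) {
        ne_view[per[i]] = ne_alloc[i];
    }

    // The permuted view's logical shape matches the original extents; only the strides differ.
    assert(ne_view == ne_orig);
    return 0;
}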