@@ -1459,11 +1459,13 @@ struct test_cpy : public test_case {
14591459 const ggml_type type_src;
14601460 const ggml_type type_dst;
14611461 const std::array<int64_t , 4 > ne;
1462- const std::array<int64_t , 4 > permute;
1462+ const std::array<int64_t , 4 > permute_src;
1463+ const std::array<int64_t , 4 > permute_dst;
14631464 bool _src_use_permute;
1465+ bool _dst_use_permute;
14641466
14651467 std::string vars () override {
1466- return VARS_TO_STR4 (type_src, type_dst, ne, permute );
1468+ return VARS_TO_STR5 (type_src, type_dst, ne, permute_src, permute_dst );
14671469 }
14681470
14691471 double max_nmse_err () override {
@@ -1476,23 +1478,30 @@ struct test_cpy : public test_case {
14761478
14771479 test_cpy (ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
14781480 std::array<int64_t , 4 > ne = {10 , 10 , 10 , 1 },
1479- std::array<int64_t , 4 > permute = {0 , 0 , 0 , 0 })
1480- : type_src(type_src), type_dst(type_dst), ne(ne), permute(permute),
1481- _src_use_permute (permute[0 ] + permute[1 ] + permute[2 ] + permute[3 ] > 0 ) {}
1481+ std::array<int64_t , 4 > permute_src = {0 , 0 , 0 , 0 },
1482+ std::array<int64_t , 4 > permute_dst = {0 , 0 , 0 , 0 })
1483+ : type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst),
1484+ _src_use_permute (permute_src[0 ] + permute_src[1 ] + permute_src[2 ] + permute_src[3 ] > 0 ),
1485+ _dst_use_permute(permute_dst[0 ] + permute_dst[1 ] + permute_dst[2 ] + permute_dst[3 ] > 0 ) {}
14821486
14831487 ggml_tensor * build_graph (ggml_context * ctx) override {
14841488 ggml_tensor * src = ggml_new_tensor (ctx, type_src, 4 , ne.data ());
14851489 ggml_set_param (ctx, src);
14861490 ggml_set_name (src, " src" );
14871491
14881492 if (_src_use_permute) {
1489- src = ggml_permute (ctx, src, permute [0 ], permute [1 ], permute [2 ], permute [3 ]);
1493+ src = ggml_permute (ctx, src, permute_src [0 ], permute_src [1 ], permute_src [2 ], permute_src [3 ]);
14901494 ggml_set_name (src, " src_permuted" );
14911495 }
14921496
1493- ggml_tensor* dst = ggml_new_tensor (ctx, type_dst, 4 , src->ne );
1497+ ggml_tensor * dst = ggml_new_tensor (ctx, type_dst, 4 , src->ne );
14941498 ggml_set_name (dst, " dst" );
14951499
1500+ if (_dst_use_permute) {
1501+ dst = ggml_permute (ctx, dst, permute_dst[0 ], permute_dst[1 ], permute_dst[2 ], permute_dst[3 ]);
1502+ ggml_set_name (dst, " dst_permuted" );
1503+ }
1504+
14961505 ggml_tensor * out = ggml_cpy (ctx, src, dst);
14971506 ggml_set_name (out, " out" );
14981507
@@ -3930,13 +3939,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
39303939 }
39313940
39323941 // same-type copy
3933- for (int nb = 1 ; nb < 4 ; ++nb) {
3934- for (ggml_type type : all_types) {
3935- const auto neb = ggml_blck_size (type);
3942+ for (ggml_type type : all_types) {
3943+ const auto nk = ggml_blck_size (type);
39363944
3937- test_cases.emplace_back (new test_cpy (type, type, {nb*neb, 2 , 3 , 4 }, {0 , 1 , 2 , 3 }));
3938- test_cases.emplace_back (new test_cpy (type, type, {nb*neb, 2 , 3 , 4 }, {0 , 2 , 1 , 3 }));
3939- test_cases.emplace_back (new test_cpy (type, type, {nb*neb, 2 , 3 , 4 }, {0 , 3 , 1 , 2 }));
3945+ for (int k = 1 ; k < 4 ; ++k) {
3946+ test_cases.emplace_back (new test_cpy (type, type, {k*nk, 2 , 3 , 4 }));
3947+ test_cases.emplace_back (new test_cpy (type, type, {k*nk, 2 , 3 , 4 }, {0 , 2 , 1 , 3 }));
3948+ test_cases.emplace_back (new test_cpy (type, type, {k*nk, 2 , 3 , 4 }, {0 , 3 , 1 , 2 }, {0 , 2 , 1 , 3 }));
39403949 }
39413950 }
39423951
0 commit comments