@@ -4588,20 +4588,31 @@ struct test_topk_moe: public test_case {
45884588struct test_sum : public test_case {
45894589 const ggml_type type;
45904590 const std::array<int64_t , 4 > ne;
4591+ const std::array<int64_t , 4 > permute;
4592+ bool _use_permute;
45914593
45924594 std::string vars () override {
4593- return VARS_TO_STR2 (type, ne);
4595+ std::string v = VARS_TO_STR2 (type, ne);
4596+ if (_use_permute) v += " ," + VAR_TO_STR (permute);
4597+ return v;
45944598 }
45954599
45964600 test_sum (ggml_type type = GGML_TYPE_F32,
4597- std::array<int64_t , 4 > ne = {10 , 5 , 4 , 3 })
4598- : type(type), ne(ne) {}
4601+ std::array<int64_t , 4 > ne = {10 , 5 , 4 , 3 },
4602+ std::array<int64_t , 4 > permute = {0 , 0 , 0 , 0 })
4603+ : type(type), ne(ne), permute(permute),
4604+ _use_permute (permute[0 ] + permute[1 ] + permute[2 ] + permute[3 ] > 0 ) {}
45994605
46004606 ggml_tensor * build_graph (ggml_context * ctx) override {
46014607 ggml_tensor * a = ggml_new_tensor (ctx, type, 4 , ne.data ());
46024608 ggml_set_param (a);
46034609 ggml_set_name (a, " a" );
46044610
4611+ if (_use_permute) {
4612+ a = ggml_permute (ctx, a, permute[0 ], permute[1 ], permute[2 ], permute[3 ]);
4613+ ggml_set_name (a, " a_permuted" );
4614+ }
4615+
46054616 ggml_tensor * out = ggml_sum (ctx, a);
46064617 ggml_set_name (out, " out" );
46074618
@@ -6724,6 +6735,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
67246735
67256736 test_cases.emplace_back (new test_sum ());
67266737 test_cases.emplace_back (new test_sum_rows ());
6738+ test_cases.emplace_back (new test_sum (GGML_TYPE_F32, {11 , 5 , 6 , 3 }, {0 , 2 , 1 , 3 })); // row-contiguous but non-contiguous
6739+ test_cases.emplace_back (new test_sum (GGML_TYPE_F32, {11 , 5 , 6 , 3 }, {0 , 3 , 2 , 1 }));
6740+ test_cases.emplace_back (new test_sum (GGML_TYPE_F32, {11 , 5 , 6 , 3 }, {0 , 1 , 3 , 2 }));
67276741 test_cases.emplace_back (new test_sum_rows (GGML_TYPE_F32, { 11 , 5 , 6 , 3 }, true , false ));
67286742 test_cases.emplace_back (new test_sum_rows (GGML_TYPE_F32, { 11 , 5 , 6 , 3 }, false , true ));
67296743 test_cases.emplace_back (new test_sum_rows (GGML_TYPE_F32, { 11 , 5 , 6 , 3 }, true , true ));
@@ -6734,6 +6748,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
67346748 test_cases.emplace_back (new test_sum (GGML_TYPE_F32, { 33 , 1024 , 1 , 1 }));
67356749 test_cases.emplace_back (new test_sum_rows (GGML_TYPE_F32, { 33 , 1024 , 1 , 1 }));
67366750 test_cases.emplace_back (new test_sum (GGML_TYPE_F32, { 33 , 256 , 1 , 1 }));
6751+ test_cases.emplace_back (new test_sum (GGML_TYPE_F32, { 33 , 256 , 1 , 1 }, { 1 , 0 , 2 , 3 })); // sum dst not-contiguous
67376752 test_cases.emplace_back (new test_sum_rows (GGML_TYPE_F32, { 33 , 256 , 1 , 1 }));
67386753 test_cases.emplace_back (new test_mean (GGML_TYPE_F32, { 33 , 256 , 1 , 1 }));
67396754 test_cases.emplace_back (new test_mean (GGML_TYPE_F32, { 32769 , 1 , 1 , 1 }));
0 commit comments