@@ -2353,9 +2353,12 @@ struct test_bin_bcast : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const std::array<int, 4> nr;
+    int nf; // number of fused ops, nf == 1 -> single op (no fusion)
+
+    bool run_whole_graph() override { return true; }
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne, nr);
+        return VARS_TO_STR4(type, ne, nr, nf);
     }
 
     size_t op_size(ggml_tensor * t) override {
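
The run_whole_graph() override is what lets fusion actually be exercised by the test: when the harness compares results node by node, every intermediate tensor has to be materialized, which blocks the backend from fusing the chain. A minimal sketch of the base-class hook being overridden here, assuming a default of false (neither the default nor its comment appears in this diff):

    // Assumed shape of the hook in test_case (hypothetical reconstruction):
    struct test_case {
        // Default: evaluate and compare node by node, which keeps all
        // intermediates observable and so inhibits backend fusion.
        virtual bool run_whole_graph() { return false; }
    };
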
@@ -2364,24 +2367,35 @@ struct test_bin_bcast : public test_case {
 
     test_bin_bcast(op_t op, ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 1, 1},
-            std::array<int, 4> nr = {1, 2, 1, 1})
-        : op(op), type(type), ne(ne), nr(nr) {}
+            std::array<int, 4> nr = {1, 2, 1, 1},
+            int nf = 1)
+        : op(op), type(type), ne(ne), nr(nr), nf(nf) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
+        GGML_ASSERT(nf <= 8);
+
         ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
         ggml_set_name(a, "a");
 
-        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_name(b, "b");
+        ggml_tensor * b[8];
+        for (int i = 0; i < nf; ++i) {
+            b[i] = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_name(b[i], (std::string("b") + std::to_string(i)).c_str());
+        }
 
         // The backward pass supports broadcasting only for GGML_ADD:
-        const bool grad_supported = op == ggml_add || ggml_are_same_shape(a, b);
+        const bool grad_supported = op == ggml_add && ggml_are_same_shape(a, b[0]) && nf == 1;
         if (grad_supported) {
             ggml_set_param(a);
-            ggml_set_param(b);
+            ggml_set_param(b[0]);
+        }
+
+        ggml_tensor * out = a;
+
+        for (int i = 0; i < nf; ++i) {
+            out = op(ctx, out, b[i]);
         }
 
-        ggml_tensor * out = op(ctx, a, b);
         ggml_set_name(out, "out");
 
         return out;
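
With nf > 1 the graph is a left-leaning chain, out = (((a + b0) + b1) + ...) + b[nf-1], which a backend executing the whole graph may fuse into a single kernel. A standalone sketch of the same construction for nf == 3, using only public ggml API calls (the sizes are illustrative, not taken from the test):

    #include "ggml.h"

    // Builds out = ((a + b0) + b1) + b2: three broadcast adds in a chain.
    static ggml_tensor * build_fused_add_chain(ggml_context * ctx) {
        ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 20, 5, 4, 3);
        ggml_tensor * out = a;
        for (int i = 0; i < 3; ++i) {
            // each b is [10, 5, 4, 3] and broadcasts along dim 0 against out
            ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 10, 5, 4, 3);
            out = ggml_add(ctx, out, b);
        }
        return out;
    }
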
@@ -5151,6 +5165,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         // add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1});
     }
 
+    // fusion
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}, 2));
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 2, 1, 1}, 3));
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 1}, 4));
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 2}, 5));
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 2}, 6));
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 2, 2}, 7));
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {2, 2, 2, 2}, 8));
+
     test_cases.emplace_back(new test_add1());
     test_cases.emplace_back(new test_scale());
     test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
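
Reading the new cases: ne is the shape of each b[i], nr is the per-dimension repeat factor, so a has shape ne[i]*nr[i], and the final argument is the fused-chain length nf; together the seven cases sweep broadcasting across every dimension and chain lengths 2 through 8. Worked out for the first case, using the shape arithmetic from build_graph above:

    // test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}, 2)
    //   a      : [20, 5, 4, 3]   (ne[i] * nr[i])
    //   b0, b1 : [10, 5, 4, 3]   each broadcast along dim 0
    //   graph  : out = (a + b0) + b1
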