@@ -2561,7 +2561,7 @@ struct test_rms_norm : public test_case {
25612561 const float eps;
25622562
25632563 std::string vars () override {
2564- return VARS_TO_STR4 (type, ne, v, eps);
2564+ return VARS_TO_STR5 (type, ne, v, eps, v );
25652565 }
25662566
25672567 test_rms_norm (ggml_type type = GGML_TYPE_F32,
@@ -2641,6 +2641,7 @@ struct test_rms_norm_mul_add : public test_case {
26412641 const ggml_type type;
26422642 const std::array<int64_t , 4 > ne;
26432643 const float eps;
2644+ const bool broadcast;
26442645
26452646 std::string op_desc (ggml_tensor * t) override {
26462647 GGML_UNUSED (t);
@@ -2655,13 +2656,18 @@ struct test_rms_norm_mul_add : public test_case {
26552656
26562657 test_rms_norm_mul_add (ggml_type type = GGML_TYPE_F32,
26572658 std::array<int64_t , 4 > ne = {64 , 5 , 4 , 3 },
2658- float eps = 1e-6f )
2659- : type(type), ne(ne), eps(eps) {}
2659+ float eps = 1e-6f , bool broadcast = false )
2660+ : type(type), ne(ne), eps(eps), broadcast(broadcast) {}
26602661
26612662 ggml_tensor * build_graph (ggml_context * ctx) override {
2662- ggml_tensor * a = ggml_new_tensor (ctx, type, 4 , ne.data ());
2663+ std::array<int64_t , 4 > broadcast_dims = {ne[0 ]*2 , ne[1 ]*3 , ne[2 ]*3 , ne[3 ]*4 };
2664+
2665+ ggml_tensor * a = ggml_new_tensor (ctx, type, 4 , broadcast ? broadcast_dims.data () : ne.data ());
26632666 ggml_tensor * b = ggml_new_tensor (ctx, type, 4 , ne.data ());
26642667 ggml_tensor * c = ggml_new_tensor (ctx, type, 4 , ne.data ());
2668+
2669+
2670+
26652671 ggml_set_param (a);
26662672 ggml_set_name (a, " a" );
26672673 ggml_set_param (b);
@@ -5353,6 +5359,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
53535359 }
53545360 for (float eps : {0 .0f , 1e-6f , 1e-4f , 1e-1f , 1 .0f }) {
53555361 test_cases.emplace_back (new test_rms_norm_mul_add (GGML_TYPE_F32, {64 , 5 , 4 , 3 }, eps));
5362+ test_cases.emplace_back (new test_rms_norm_mul_add (GGML_TYPE_F32, {64 , 5 , 4 , 3 }, eps, true ));
53565363 }
53575364
53585365 test_cases.emplace_back (new test_l2_norm (GGML_TYPE_F32, {64 , 5 , 4 , 3 }, 1e-12f ));
0 commit comments