Skip to content

Commit 90cba7f

Browse files
committed
Add testcase about the broadcast
1 parent f58bbb1 commit 90cba7f

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

tests/test-backend-ops.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2561,7 +2561,7 @@ struct test_rms_norm : public test_case {
25612561
const float eps;
25622562

25632563
std::string vars() override {
2564-
return VARS_TO_STR4(type, ne, v, eps);
2564+
return VARS_TO_STR5(type, ne, v, eps, v);
25652565
}
25662566

25672567
test_rms_norm(ggml_type type = GGML_TYPE_F32,
@@ -2641,6 +2641,7 @@ struct test_rms_norm_mul_add : public test_case {
26412641
const ggml_type type;
26422642
const std::array<int64_t, 4> ne;
26432643
const float eps;
2644+
const bool broadcast;
26442645

26452646
std::string op_desc(ggml_tensor * t) override {
26462647
GGML_UNUSED(t);
@@ -2655,13 +2656,18 @@ struct test_rms_norm_mul_add : public test_case {
26552656

26562657
test_rms_norm_mul_add(ggml_type type = GGML_TYPE_F32,
26572658
std::array<int64_t, 4> ne = {64, 5, 4, 3},
2658-
float eps = 1e-6f)
2659-
: type(type), ne(ne), eps(eps) {}
2659+
float eps = 1e-6f, bool broadcast = false)
2660+
: type(type), ne(ne), eps(eps), broadcast(broadcast) {}
26602661

26612662
ggml_tensor * build_graph(ggml_context * ctx) override {
2662-
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
2663+
std::array<int64_t, 4> broadcast_dims = {ne[0]*2, ne[1]*3, ne[2]*3, ne[3]*4};
2664+
2665+
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, broadcast ? broadcast_dims.data() : ne.data());
26632666
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
26642667
ggml_tensor * c = ggml_new_tensor(ctx, type, 4, ne.data());
2668+
2669+
2670+
26652671
ggml_set_param(a);
26662672
ggml_set_name(a, "a");
26672673
ggml_set_param(b);
@@ -5353,6 +5359,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
53535359
}
53545360
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
53555361
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
5362+
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
53565363
}
53575364

53585365
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));

0 commit comments

Comments
 (0)