Skip to content

Commit f46ddba

Browse files
committed
tests : more consistent implementation + more tests
ggml-ci
1 parent 838e89d commit f46ddba

File tree

1 file changed

+25
-20
lines changed

1 file changed

+25
-20
lines changed

tests/test-backend-ops.cpp

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,37 +1216,34 @@ struct test_get_rows_back : public test_case {
12161216
// GGML_OP_SET_ROWS
12171217
struct test_set_rows : public test_case {
12181218
const ggml_type type;
1219-
const int n; // cols
1220-
const int m; // rows
1219+
const std::array<int64_t, 4> ne;
1220+
const std::array<int, 2> nr23; // broadcast only dims 2 and 3
12211221
const int r; // rows to set
1222-
const int b0; // batch size
1223-
const int b1; // batch size
1224-
const int bs; // batch size src (for testing broadcast)
12251222
const bool v; // view (non-contiguous src1)
12261223

12271224
std::string vars() override {
1228-
return VARS_TO_STR7(type, n, m, r, b0, bs, v);
1225+
return VARS_TO_STR5(type, ne, nr23, r, v);
12291226
}
12301227

1231-
test_set_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, int bs = 1, bool v = false)
1232-
: type(type), n(n), m(m), r(r), b0(b), b1(3), bs(bs), v(v) {
1233-
GGML_ASSERT(b0 % bs == 0 && "b0 must be a multiple of bs");
1234-
GGML_ASSERT(r <= m && "r must be less than or equal to m");
1235-
}
1228+
test_set_rows(ggml_type type,
1229+
std::array<int64_t, 4> ne,
1230+
std::array<int, 2> nr23,
1231+
int r, bool v = false)
1232+
: type(type), ne(ne), nr23(nr23), r(r), v(v) {}
12361233

12371234
ggml_tensor * build_graph(ggml_context * ctx) override {
1238-
ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, n, m, b0, b1);
1235+
ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
12391236
ggml_set_name(dst, "dst");
12401237

1241-
ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n, r, b0, b1);
1238+
ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], r, ne[2]*nr23[0], ne[3]*nr23[1]);
12421239
ggml_set_name(src, "src");
12431240

1244-
ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, r, bs, b1);
1241+
ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, r, ne[2], ne[3]);
12451242
ggml_set_name(row_idxs, "row_idxs");
12461243

12471244
if (v) {
1248-
src = ggml_view_4d(ctx, src, n, r/2, b0, b1, src->nb[1], src->nb[2], src->nb[3], 0);
1249-
row_idxs = ggml_view_3d(ctx, row_idxs, r/2, bs, b1, row_idxs->nb[1], row_idxs->nb[2], 0);
1245+
src = ggml_view_4d(ctx, src, ne[0], r/2, ne[2]*nr23[0], ne[3]*nr23[1], src->nb[1], src->nb[2], src->nb[3], 0);
1246+
row_idxs = ggml_view_3d(ctx, row_idxs, r/2, ne[2], ne[3], row_idxs->nb[1], row_idxs->nb[2], 0);
12501247
ggml_set_name(row_idxs, "view_of_rows");
12511248
}
12521249

@@ -1268,8 +1265,8 @@ struct test_set_rows : public test_case {
12681265
for (int i2 = 0; i2 < t->ne[2]; i2++) {
12691266
for (int i1 = 0; i1 < t->ne[1]; i1++) {
12701267
// generate a shuffled subset of row indices
1271-
std::vector<int64_t> data(m);
1272-
for (int i = 0; i < m; i++) {
1268+
std::vector<int64_t> data(ne[1]);
1269+
for (int i = 0; i < ne[1]; i++) {
12731270
data[i] = i;
12741271
}
12751272
std::shuffle(data.begin(), data.end(), rng);
@@ -4057,11 +4054,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
40574054
test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_I32, 256, 5, 4, 1, v));
40584055
}
40594056

4060-
test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, 1, 8, 2, 1, 1, false));
4057+
test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, { 1, 8, 1, 3 }, { 1, 1 }, 2, false));
40614058
for (ggml_type type : all_types) {
40624059
for (int b : {1, 7}) {
40634060
for (bool v : {false, true}) {
4064-
test_cases.emplace_back(new test_set_rows(type, 256, 5, 4, b, 1, v));
4061+
test_cases.emplace_back(new test_set_rows(type, { 256, 5, b, 3 }, { 1, 1, }, 1, v));
4062+
test_cases.emplace_back(new test_set_rows(type, { 256, 11, 1, b }, { 2, 3, }, 7, v));
4063+
4064+
test_cases.emplace_back(new test_set_rows(type, { 3*ggml_blck_size(type), 3, b, 1 }, { 2, 3, }, 2, v));
4065+
4066+
if (ggml_blck_size(type) == 1) {
4067+
test_cases.emplace_back(new test_set_rows(type, { 31, 3, b, 1 }, { 2, 3, }, 2, v));
4068+
test_cases.emplace_back(new test_set_rows(type, { 33, 5, 1, b }, { 2, 3, }, 1, v));
4069+
}
40654070
}
40664071
}
40674072
}

0 commit comments

Comments
 (0)