Skip to content

Commit 77ae46f

Browse files
committed
add test for non-cont inplace rope
1 parent 530ef06 commit 77ae46f

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

tests/test-backend-ops.cpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3768,6 +3768,7 @@ struct test_rope : public test_case {
37683768
const ggml_type type;
37693769
const std::array<int64_t, 4> ne_a;
37703770
int n_dims;
3771+
int offset;
37713772
int mode;
37723773
int n_ctx; // used to generate positions
37733774
float fs; // freq_scale
@@ -3779,16 +3780,17 @@ struct test_rope : public test_case {
37793780

37803781
std::string vars() override {
37813782
// forward can be inferred from the op, does not need to be printed
3782-
return VARS_TO_STR10(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v);
3783+
return VARS_TO_STR11(type, ne_a, n_dims, offset, mode, n_ctx, fs, ef, af, ff, v);
37833784
}
37843785

37853786
test_rope(ggml_type type = GGML_TYPE_F32,
37863787
std::array<int64_t, 4> ne_a = {10, 5, 3, 1},
3787-
int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f,
3788+
int n_dims = 10, int offset = 0, int mode = 0, int n_ctx = 512, float fs = 1.0f,
37883789
float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true)
3789-
: type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward) {}
3790+
: type(type), ne_a(ne_a), n_dims(n_dims), offset(offset), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward) {}
37903791

37913792
ggml_tensor * build_graph(ggml_context * ctx) override {
3793+
bool inplace = false;
37923794
ggml_tensor * a;
37933795
if (v & 1) {
37943796
auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
@@ -3808,6 +3810,14 @@ struct test_rope : public test_case {
38083810
ggml_set_name(a, "a");
38093811
}
38103812

3813+
if (offset > 0) {
3814+
inplace = true;
3815+
a = ggml_view_3d(ctx, a, a->ne[0] - offset, a->ne[1], a->ne[2],
3816+
ggml_row_size(a->type, a->ne[0]),
3817+
ggml_row_size(a->type, a->ne[0]*a->ne[1]),
3818+
offset * ggml_element_size(a));
3819+
}
3820+
38113821
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
38123822
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
38133823

@@ -3846,12 +3856,12 @@ struct test_rope : public test_case {
38463856
}
38473857
} else {
38483858
if (forward) {
3849-
out = ggml_rope_ext (ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
3859+
out = inplace
3860+
? ggml_rope_ext_inplace(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f)
3861+
: ggml_rope_ext (ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
38503862
} else {
3851-
out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
3863+
out = ggml_rope_ext_back (ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
38523864
}
3853-
3854-
// TODO: add test with a non-contiguous view as input ; this case is needed for build_rope_2d in clip.cpp
38553865
}
38563866
ggml_set_name(out, "out");
38573867

0 commit comments

Comments
 (0)