@@ -1216,37 +1216,34 @@ struct test_get_rows_back : public test_case {
12161216// GGML_OP_SET_ROWS
12171217struct test_set_rows : public test_case {
12181218 const ggml_type type;
1219- const int n; // cols
1220- const int m ; // rows
1219+ const std::array< int64_t , 4 > ne;
1220+ const std::array< int , 2 > nr23 ; // broadcast only dims 2 and 3
12211221 const int r; // rows to set
1222- const int b0; // batch size
1223- const int b1; // batch size
1224- const int bs; // batch size src (for testing broadcast)
12251222 const bool v; // view (non-contiguous src1)
12261223
12271224 std::string vars () override {
1228- return VARS_TO_STR7 (type, n, m , r, b0, bs , v);
1225+ return VARS_TO_STR5 (type, ne, nr23 , r, v);
12291226 }
12301227
1231- test_set_rows (ggml_type type = GGML_TYPE_F32, int n = 10 , int m = 5 , int r = 3 , int b = 1 , int bs = 1 , bool v = false )
1232- : type(type), n(n), m(m), r(r), b0(b), b1( 3 ), bs(bs), v(v) {
1233- GGML_ASSERT (b0 % bs == 0 && " b0 must be a multiple of bs " );
1234- GGML_ASSERT (r <= m && " r must be less than or equal to m " );
1235- }
1228+ test_set_rows (ggml_type type,
1229+ std::array< int64_t , 4 > ne,
1230+ std::array< int , 2 > nr23,
1231+ int r, bool v = false )
1232+ : type(type), ne(ne), nr23(nr23), r(r), v(v) { }
12361233
12371234 ggml_tensor * build_graph (ggml_context * ctx) override {
1238- ggml_tensor * dst = ggml_new_tensor_4d (ctx, type, n, m, b0, b1 );
1235+ ggml_tensor * dst = ggml_new_tensor_4d (ctx, type, ne[ 0 ], ne[ 1 ], ne[ 2 ]*nr23[ 0 ], ne[ 3 ]*nr23[ 1 ] );
12391236 ggml_set_name (dst, " dst" );
12401237
1241- ggml_tensor * src = ggml_new_tensor_4d (ctx, GGML_TYPE_F32, n , r, b0, b1 );
1238+ ggml_tensor * src = ggml_new_tensor_4d (ctx, GGML_TYPE_F32, ne[ 0 ] , r, ne[ 2 ]*nr23[ 0 ], ne[ 3 ]*nr23[ 1 ] );
12421239 ggml_set_name (src, " src" );
12431240
1244- ggml_tensor * row_idxs = ggml_new_tensor_3d (ctx, GGML_TYPE_I64, r, bs, b1 );
1241+ ggml_tensor * row_idxs = ggml_new_tensor_3d (ctx, GGML_TYPE_I64, r, ne[ 2 ], ne[ 3 ] );
12451242 ggml_set_name (row_idxs, " row_idxs" );
12461243
12471244 if (v) {
1248- src = ggml_view_4d (ctx, src, n , r/2 , b0, b1 , src->nb [1 ], src->nb [2 ], src->nb [3 ], 0 );
1249- row_idxs = ggml_view_3d (ctx, row_idxs, r/2 , bs, b1 , row_idxs->nb [1 ], row_idxs->nb [2 ], 0 );
1245+ src = ggml_view_4d (ctx, src, ne[ 0 ] , r/2 , ne[ 2 ]*nr23[ 0 ], ne[ 3 ]*nr23[ 1 ] , src->nb [1 ], src->nb [2 ], src->nb [3 ], 0 );
1246+ row_idxs = ggml_view_3d (ctx, row_idxs, r/2 , ne[ 2 ], ne[ 3 ] , row_idxs->nb [1 ], row_idxs->nb [2 ], 0 );
12501247 ggml_set_name (row_idxs, " view_of_rows" );
12511248 }
12521249
@@ -1268,8 +1265,8 @@ struct test_set_rows : public test_case {
12681265 for (int i2 = 0 ; i2 < t->ne [2 ]; i2++) {
12691266 for (int i1 = 0 ; i1 < t->ne [1 ]; i1++) {
12701267 // generate a shuffled subset of row indices
1271- std::vector<int64_t > data (m );
1272- for (int i = 0 ; i < m ; i++) {
1268+ std::vector<int64_t > data (ne[ 1 ] );
1269+ for (int i = 0 ; i < ne[ 1 ] ; i++) {
12731270 data[i] = i;
12741271 }
12751272 std::shuffle (data.begin (), data.end (), rng);
@@ -4057,11 +4054,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
40574054 test_cases.emplace_back (new test_get_rows_back (GGML_TYPE_I32, 256 , 5 , 4 , 1 , v));
40584055 }
40594056
4060- test_cases.emplace_back (new test_set_rows (GGML_TYPE_F32, 1 , 8 , 2 , 1 , 1 , false ));
4057+ test_cases.emplace_back (new test_set_rows (GGML_TYPE_F32, { 1 , 8 , 1 , 3 }, { 1 , 1 }, 2 , false ));
40614058 for (ggml_type type : all_types) {
40624059 for (int b : {1 , 7 }) {
40634060 for (bool v : {false , true }) {
4064- test_cases.emplace_back (new test_set_rows (type, 256 , 5 , 4 , b, 1 , v));
4061+ test_cases.emplace_back (new test_set_rows (type, { 256 , 5 , b, 3 }, { 1 , 1 , }, 1 , v));
4062+ test_cases.emplace_back (new test_set_rows (type, { 256 , 11 , 1 , b }, { 2 , 3 , }, 7 , v));
4063+
4064+ test_cases.emplace_back (new test_set_rows (type, { 3 *ggml_blck_size (type), 3 , b, 1 }, { 2 , 3 , }, 2 , v));
4065+
4066+ if (ggml_blck_size (type) == 1 ) {
4067+ test_cases.emplace_back (new test_set_rows (type, { 31 , 3 , b, 1 }, { 2 , 3 , }, 2 , v));
4068+ test_cases.emplace_back (new test_set_rows (type, { 33 , 5 , 1 , b }, { 2 , 3 , }, 1 , v));
4069+ }
40654070 }
40664071 }
40674072 }
0 commit comments