@@ -2088,30 +2088,59 @@ struct test_ssm_scan : public test_case {
20882088 const ggml_type type;
20892089
20902090 const int64_t d_state;
2091- const int64_t d_inner;
2091+ const int64_t head_dim;
2092+ const int64_t n_head;
2093+ const int64_t n_group;
20922094 const int64_t n_seq_tokens;
20932095 const int64_t n_seqs;
20942096
20952097 std::string vars () override {
2096- return VARS_TO_STR5 (type, d_state, d_inner , n_seq_tokens, n_seqs);
2098+ return VARS_TO_STR7 (type, d_state, head_dim, n_head, n_group , n_seq_tokens, n_seqs);
20972099 }
20982100
20992101 test_ssm_scan (ggml_type type = GGML_TYPE_F32,
2100- int64_t d_state = 32 , int64_t d_inner = 32 , int64_t n_seq_tokens = 32 , int64_t n_seqs = 32 )
2101- : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
2102+ int64_t d_state = 32 ,
2103+ int64_t head_dim = 1 , // non-zero for Mamba-2
2104+ int64_t n_head = 32 ,
2105+ int64_t n_group = 1 ,
2106+ int64_t n_seq_tokens = 32 ,
2107+ int64_t n_seqs = 32 )
2108+ : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
21022109
21032110 ggml_tensor * build_graph (ggml_context * ctx) override {
2104- ggml_tensor * s = ggml_new_tensor (ctx, type, 4 , std::vector<int64_t >{ d_state, d_inner, n_seqs, 1 }.data ());
2105- ggml_tensor * x = ggml_new_tensor (ctx, type, 4 , std::vector<int64_t >{ d_inner, n_seq_tokens, n_seqs, 1 }.data ());
2106- ggml_tensor * dt = ggml_new_tensor (ctx, type, 4 , std::vector<int64_t >{ d_inner, n_seq_tokens, n_seqs, 1 }.data ());
2107- ggml_tensor * A = ggml_new_tensor (ctx, type, 4 , std::vector<int64_t >{ d_state, d_inner, 1 , 1 }.data ());
2108- ggml_tensor * B = ggml_new_tensor (ctx, type, 4 , std::vector<int64_t >{ d_state, n_seq_tokens, n_seqs, 1 }.data ());
2109- ggml_tensor * C = ggml_new_tensor (ctx, type, 4 , std::vector<int64_t >{ d_state, n_seq_tokens, n_seqs, 1 }.data ());
2110- ggml_tensor * out = ggml_ssm_scan (ctx, s, x, dt, A, B, C);
2111+ ggml_tensor * s = ggml_new_tensor_4d (ctx, type, d_state, head_dim, n_head, n_seqs);
2112+ ggml_tensor * x = ggml_new_tensor_4d (ctx, type, head_dim, n_head, n_seq_tokens, n_seqs);
2113+ ggml_tensor * dt = ggml_new_tensor_3d (ctx, type, n_head, n_seq_tokens, n_seqs);
2114+ ggml_tensor * A = ggml_new_tensor_2d (ctx, type, (head_dim > 1 ) ? 1 : d_state, n_head);
2115+ ggml_tensor * B = ggml_new_tensor_4d (ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
2116+ ggml_tensor * C = ggml_new_tensor_4d (ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
2117+ ggml_tensor * ids = ggml_new_tensor_1d (ctx, GGML_TYPE_I32, n_seqs);
2118+ ggml_tensor * out = ggml_ssm_scan (ctx, s, x, dt, A, B, C, ids);
21112119 return out;
21122120 }
2113- };
21142121
2122+ // similar to test_mul_mat_id
2123+ void initialize_tensors (ggml_context * ctx) override {
2124+ std::random_device rd;
2125+ std::default_random_engine rng (rd ());
2126+ for (ggml_tensor * t = ggml_get_first_tensor (ctx); t != NULL ; t = ggml_get_next_tensor (ctx, t)) {
2127+ if (t->type == GGML_TYPE_I32) {
2128+ if (ggml_is_view_op (t->op )) { continue ; }
2129+ // ids
2130+ for (int64_t r = 0 ; r < ggml_nrows (t); r++) {
2131+ std::vector<int32_t > data (t->ne [0 ]);
2132+ for (int i = 0 ; i < t->ne [0 ]; i++) {
2133+ data[i] = i;
2134+ }
2135+ std::shuffle (data.begin (), data.end (), rng);
2136+ ggml_backend_tensor_set (t, data.data (), r * t->nb [1 ], t->ne [0 ] * sizeof (int32_t ));
2137+ }
2138+ } else {
2139+ init_tensor_uniform (t);
2140+ }
2141+ }
2142+ }
2143+ };
21152144// GGML_OP_RWKV_WKV6
21162145struct test_rwkv_wkv6 : public test_case {
21172146 const ggml_type type;
@@ -3321,7 +3350,7 @@ struct test_upscale_ext : public test_case {
33213350 ggml_tensor * a = ggml_new_tensor (ctx, type, 4 , ne.data ());
33223351 ggml_set_name (a, " a" );
33233352
3324- ggml_tensor * out = ggml_upscale_ext (ctx, a, ne_tgt[0 ], ne_tgt[1 ],ne_tgt[2 ], ne_tgt[3 ], mode);
3353+ ggml_tensor * out = ggml_interpolate (ctx, a, ne_tgt[0 ], ne_tgt[1 ],ne_tgt[2 ], ne_tgt[3 ], mode);
33253354 ggml_set_name (out, " out" );
33263355
33273356 return out;
0 commit comments