@@ -1543,6 +1543,36 @@ struct test_ssm_scan : public test_case {
     }
 };
 
+// GGML_OP_RWKV_WKV
+struct test_rwkv_wkv : public test_case {
+    const ggml_type type;
+
+    const int64_t head_count;
+    const int64_t head_size;
+    const int64_t n_seq_tokens;
+    const int64_t n_seqs;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
+    }
+
+    test_rwkv_wkv(ggml_type type = GGML_TYPE_F32,
+            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
+        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int64_t n_tokens = n_seq_tokens * n_seqs;
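+        // shapes use ggml's innermost-first dim order; tf and td are presumably
+        // RWKV's per-head time_first and time_decay terms, and s packs one
+        // head_size x head_size wkv state per head for each sequence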
+        ggml_tensor * r   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
+        ggml_tensor * k   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_size, 1, head_count, n_tokens }.data());
+        ggml_tensor * v   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
+        ggml_tensor * tf  = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
+        ggml_tensor * td  = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
+        ggml_tensor * s   = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
+        ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, s);
+        return out;
+    }
+};
+
 // GGML_OP_MUL_MAT
 struct test_mul_mat : public test_case {
     const ggml_type type_a;
@@ -3337,6 +3367,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
     test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
 
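+    // test_rwkv_wkv args: head_count, head_size, n_seq_tokens, n_seqs;
+    // cases cover single-token, single-sequence prefill, and batched runs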
+    test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 1, 1));
+    test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 1));
+    test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 4));
+    test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 128, 4));
+
 #if 1
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
@@ -3564,7 +3599,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
                     if (hs != 128 && logit_softcap != 0.0f) continue;
                     for (int nh : { 32, }) {
                         for (int kv : { 512, 1024, }) {
-                            for (int nb : { 1, 2, 4, 8, }) {
+                            for (int nb : { 1, 3, 32, 35, }) {
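+                                // batch sizes include non-power-of-two values (3, 35)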
                                 for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
                                     test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
                                 }