@@ -4258,26 +4258,32 @@ struct test_flash_attn_ext : public test_case {
42584258 const int64_t hsk_padded = GGML_PAD (hsk, ggml_blck_size (type_KV));
42594259 const int64_t hsv_padded = GGML_PAD (hsv, ggml_blck_size (type_KV));
42604260
4261- auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
4261+ auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, bool is_view ) -> ggml_tensor * {
42624262 int64_t ne[4 ] = {ne0, ne1, ne2, ne3};
42634263 int64_t ne_perm[4 ];
42644264 for (int i = 0 ; i < 4 ; ++i) {
42654265 ne_perm[permute[i]] = ne[i];
42664266 }
4267- ggml_tensor * t = ggml_new_tensor_4d (ctx, type, ne_perm[0 ], ne_perm[1 ], ne_perm[2 ], ne_perm[3 ]);
4267+ ggml_tensor * t;
4268+ if (is_view) {
4269+ ggml_tensor * t0 = ggml_new_tensor_4d (ctx, type, ne_perm[0 ], 2 *ne_perm[1 ], ne_perm[2 ], ne_perm[3 ]);
4270+ t = ggml_view_4d (ctx, t0, ne_perm[0 ], ne_perm[1 ], ne_perm[2 ], ne_perm[3 ], t0->nb [1 ], t0->nb [2 ], t0->nb [3 ], 0 );
4271+ } else {
4272+ t = ggml_new_tensor_4d (ctx, type, ne_perm[0 ], ne_perm[1 ], ne_perm[2 ], ne_perm[3 ]);
4273+ }
42684274 if (permute != std::array<int32_t , 4 >{0 , 1 , 2 , 3 }) {
42694275 t = ggml_permute (ctx, t, permute[0 ], permute[1 ], permute[2 ], permute[3 ]);
42704276 }
42714277 return t;
42724278 };
42734279
4274- ggml_tensor * q = create_permuted (GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0 ], nr23[1 ]);
4280+ ggml_tensor * q = create_permuted (GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0 ], nr23[1 ], false );
42754281 ggml_set_name (q, " q" );
42764282
4277- ggml_tensor * k = create_permuted (type_KV, hsk_padded, kv, nh, nr23[1 ]);
4283+ ggml_tensor * k = create_permuted (type_KV, hsk_padded, kv, nh, nr23[1 ], true ); // the K tensor is usually a view of the K cache
42784284 ggml_set_name (k, " k" );
42794285
4280- ggml_tensor * v = create_permuted (type_KV, hsv_padded, kv, nh, nr23[1 ]);
4286+ ggml_tensor * v = create_permuted (type_KV, hsv_padded, kv, nh, nr23[1 ], true ); // the V tensor is usually a view of the V cache
42814287 ggml_set_name (v, " v" );
42824288
42834289 ggml_tensor * m = nullptr ;
0 commit comments