@@ -163,7 +163,6 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
163163 } \
164164 } while (0 )
165165
166- // #define WHISPER_USE_FLASH_FF
167166#define WHISPER_MAX_DECODERS 8
168167#define WHISPER_MAX_NODES 4096
169168
@@ -2104,9 +2103,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
21042103
21052104 struct ggml_tensor * Q =
21062105 ggml_permute (ctx0,
2107- ggml_cpy (ctx0,
2108- Qcur,
2109- ggml_new_tensor_3d (ctx0, GGML_TYPE_F32, n_state_head, n_head, n_ctx)),
2106+ ggml_reshape_3d (ctx0, Qcur, n_state_head, n_head, n_ctx),
21102107 0 , 2 , 1 , 3 );
21112108
21122109 if (wctx.params .flash_attn ) {
@@ -2133,9 +2130,9 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
21332130 } else {
21342131 struct ggml_tensor * K =
21352132 ggml_permute (ctx0,
2136- ggml_cpy (ctx0,
2137- Kcur,
2138- ggml_new_tensor_3d (ctx0, wctx.itype , n_state_head, n_head, n_ctx) ),
2133+ ggml_cast (ctx0,
2134+ ggml_reshape_3d (ctx0, Kcur, n_state_head, n_head, n_ctx) ,
2135+ wctx.itype ),
21392136 0 , 2 , 1 , 3 );
21402137
21412138 // K * Q
@@ -2144,22 +2141,19 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
21442141 struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext (ctx0, KQ, nullptr , KQscale, 0 .0f );
21452142
21462143 struct ggml_tensor * V =
2147- ggml_cpy (ctx0,
2144+ ggml_cast (ctx0,
21482145 ggml_permute (ctx0,
21492146 ggml_reshape_3d (ctx0,
21502147 Vcur,
21512148 n_state_head, n_head, n_ctx),
21522149 1 , 2 , 0 , 3 ),
2153- ggml_new_tensor_3d (ctx0, wctx.itype , n_ctx, n_state_head, n_head)
2154- );
2150+ wctx.itype );
21552151
21562152 struct ggml_tensor * KQV = ggml_mul_mat (ctx0, V, KQ_soft_max);
21572153
21582154 struct ggml_tensor * KQV_merged = ggml_permute (ctx0, KQV, 0 , 2 , 1 , 3 );
21592155
2160- cur = ggml_cpy (ctx0,
2161- KQV_merged,
2162- ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_state, n_ctx));
2156+ cur = ggml_cont_2d (ctx0, KQV_merged, n_state, n_ctx);
21632157 }
21642158 }
21652159
@@ -2189,11 +2183,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
21892183 layer.mlp_ln_b );
21902184 }
21912185
2192- #ifdef WHISPER_USE_FLASH_FF
2193- cur = ggml_flash_ff (ctx0,
2194- ggml_cpy (ctx0, cur, ggml_new_tensor_2d (ctx0, wstate.itype , n_state, n_ctx)),
2195- layer.mlp_0_w , layer.mlp_0_b , layer.mlp_1_w , layer.mlp_1_b );
2196- #else
21972186 // fully connected
21982187 cur = ggml_mul_mat (ctx0,
21992188 layer.mlp_0_w ,
@@ -2210,7 +2199,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
22102199 cur);
22112200
22122201 cur = ggml_add (ctx0, cur, layer.mlp_1_b );
2213- #endif
22142202 }
22152203
22162204 inpL = ggml_add (ctx0, cur, inpFF);
@@ -2586,9 +2574,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
25862574
25872575 struct ggml_tensor * KQV_merged = ggml_permute (ctx0, KQV, 0 , 2 , 1 , 3 );
25882576
2589- cur = ggml_cpy (ctx0,
2590- KQV_merged,
2591- ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_state, n_tokens));
2577+ cur = ggml_cont_2d (ctx0, KQV_merged, n_state, n_tokens);
25922578 }
25932579 }
25942580
@@ -2695,9 +2681,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
26952681
26962682 struct ggml_tensor * KQV_merged = ggml_permute (ctx0, KQV, 0 , 2 , 1 , 3 );
26972683
2698- cur = ggml_cpy (ctx0,
2699- KQV_merged,
2700- ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_state, n_tokens));
2684+ cur = ggml_cont_2d (ctx0, KQV_merged, n_state, n_tokens);
27012685 }
27022686 }
27032687
0 commit comments