Skip to content

Commit 1874968

Browse files
committed
whisper : reduce memory overhead from unused input tensors
1 parent 3be0c57 commit 1874968

File tree

1 file changed

+9
-25
lines changed

1 file changed

+9
-25
lines changed

src/whisper.cpp

Lines changed: 9 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -163,7 +163,6 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
163163
} \
164164
} while (0)
165165

166-
//#define WHISPER_USE_FLASH_FF
167166
#define WHISPER_MAX_DECODERS 8
168167
#define WHISPER_MAX_NODES 4096
169168

@@ -2104,9 +2103,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
21042103

21052104
struct ggml_tensor * Q =
21062105
ggml_permute(ctx0,
2107-
ggml_cpy(ctx0,
2108-
Qcur,
2109-
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state_head, n_head, n_ctx)),
2106+
ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_ctx),
21102107
0, 2, 1, 3);
21112108

21122109
if (wctx.params.flash_attn) {
@@ -2133,9 +2130,9 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
21332130
} else {
21342131
struct ggml_tensor * K =
21352132
ggml_permute(ctx0,
2136-
ggml_cpy(ctx0,
2137-
Kcur,
2138-
ggml_new_tensor_3d(ctx0, wctx.itype, n_state_head, n_head, n_ctx)),
2133+
ggml_cast(ctx0,
2134+
ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
2135+
wctx.itype),
21392136
0, 2, 1, 3);
21402137

21412138
// K * Q
@@ -2144,22 +2141,19 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
21442141
struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
21452142

21462143
struct ggml_tensor * V =
2147-
ggml_cpy(ctx0,
2144+
ggml_cast(ctx0,
21482145
ggml_permute(ctx0,
21492146
ggml_reshape_3d(ctx0,
21502147
Vcur,
21512148
n_state_head, n_head, n_ctx),
21522149
1, 2, 0, 3),
2153-
ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state_head, n_head)
2154-
);
2150+
wctx.itype);
21552151

21562152
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
21572153

21582154
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
21592155

2160-
cur = ggml_cpy(ctx0,
2161-
KQV_merged,
2162-
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
2156+
cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_ctx);
21632157
}
21642158
}
21652159

@@ -2189,11 +2183,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
21892183
layer.mlp_ln_b);
21902184
}
21912185

2192-
#ifdef WHISPER_USE_FLASH_FF
2193-
cur = ggml_flash_ff(ctx0,
2194-
ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
2195-
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
2196-
#else
21972186
// fully connected
21982187
cur = ggml_mul_mat(ctx0,
21992188
layer.mlp_0_w,
@@ -2210,7 +2199,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
22102199
cur);
22112200

22122201
cur = ggml_add(ctx0, cur, layer.mlp_1_b);
2213-
#endif
22142202
}
22152203

22162204
inpL = ggml_add(ctx0, cur, inpFF);
@@ -2586,9 +2574,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
25862574

25872575
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
25882576

2589-
cur = ggml_cpy(ctx0,
2590-
KQV_merged,
2591-
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
2577+
cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
25922578
}
25932579
}
25942580

@@ -2695,9 +2681,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
26952681

26962682
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
26972683

2698-
cur = ggml_cpy(ctx0,
2699-
KQV_merged,
2700-
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
2684+
cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
27012685
}
27022686
}
27032687

0 commit comments

Comments (0)