Skip to content

Commit 0be1c4d

Browse files
committed
fix DTW crash: fill the I32 trace tensor with whisper_set_i32 instead of whisper_set_f32, and allocate the ggml context with dtw_mem_size (no_alloc = false)
1 parent 579539a commit 0be1c4d

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

src/whisper.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,16 @@ static ggml_tensor * whisper_set_f32(struct ggml_tensor * t, float v) {
212212
return t;
213213
}
214214

215+
// Fill a contiguous I32 tensor with a constant value; returns t so the call
// can be chained (mirrors whisper_set_f32 above).
static ggml_tensor * whisper_set_i32(struct ggml_tensor * t, int32_t v) {
    GGML_ASSERT(t->type == GGML_TYPE_I32);
    GGML_ASSERT(ggml_is_contiguous(t));
    // ggml_nelements() returns int64_t; keep the count signed so the loop
    // comparison is not a signed/unsigned mismatch.
    int64_t nels = ggml_nelements(t);
    for (int64_t i = 0; i < nels; ++i) {
        ((int32_t *) t->data)[i] = v;
    }
    return t;
}
224+
215225
static float whisper_get_f32_nd(const struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
216226
GGML_ASSERT(t->type == GGML_TYPE_F32);
217227
void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3];
@@ -3567,7 +3577,7 @@ struct whisper_context_params whisper_context_default_params() {
35673577
/*.n_heads =*/ 0,
35683578
/*.heads =*/ NULL,
35693579
},
3570-
/*.dtw_mem_size =*/ 1024*1024*128, // TODO: probably can be removed now
3580+
/*.dtw_mem_size =*/ 1024*1024*128,
35713581
};
35723582
return result;
35733583
}
@@ -7170,7 +7180,7 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
71707180
struct ggml_tensor * trace = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, N + 1, M + 1);
71717181

71727182
cost = whisper_set_f32(cost, INFINITY);
7173-
trace = whisper_set_f32(trace, -1);
7183+
trace = whisper_set_i32(trace, -1);
71747184
whisper_set_f32_nd(cost, 0, 0, 0, 0, 0.0);
71757185

71767186
// dtw
@@ -7306,9 +7316,9 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
73067316
// Our ggml buffer should be pre-allocated somewhere during init and reused
73077317
// when we call this function
73087318
struct ggml_init_params gparams = {
7309-
/*.mem_size =*/ ggml_tensor_overhead()*1024 + ggml_graph_overhead(),
7319+
/*.mem_size =*/ ctx->params.dtw_mem_size,
73107320
/*.mem_buffer =*/ NULL,
7311-
/*.no_alloc =*/ true,
7321+
/*.no_alloc =*/ false,
73127322
};
73137323
struct ggml_context * gctx = ggml_init(gparams);
73147324

@@ -7403,7 +7413,6 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
74037413
ggml_build_forward_expand(gf, w);
74047414

74057415
ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
7406-
ggml_backend_buffer_ptr buf { ggml_backend_alloc_ctx_tensors(gctx, backend.get()) };
74077416
ggml_backend_graph_compute(backend.get(), gf);
74087417

74097418
ggml_tensor * alignment = dtw_and_backtrace(gctx, w);

0 commit comments

Comments
 (0)