@@ -212,6 +212,16 @@ static ggml_tensor * whisper_set_f32(struct ggml_tensor * t, float v) {
212212 return t;
213213}
214214
215+ static ggml_tensor * whisper_set_i32 (struct ggml_tensor * t, int32_t v) {
216+ GGML_ASSERT (t->type == GGML_TYPE_I32);
217+ GGML_ASSERT (ggml_is_contiguous (t));
218+ size_t nels = ggml_nelements (t);
219+ for (int64_t i = 0 ; i < nels; ++i) {
220+ ((int32_t *) t->data )[i] = v;
221+ }
222+ return t;
223+ }
224+
215225static float whisper_get_f32_nd (const struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
216226 GGML_ASSERT (t->type == GGML_TYPE_F32);
217227 void * data = (char *) t->data + i0*t->nb [0 ] + i1*t->nb [1 ] + i2*t->nb [2 ] + i3*t->nb [3 ];
@@ -3567,7 +3577,7 @@ struct whisper_context_params whisper_context_default_params() {
35673577 /* .n_heads =*/ 0 ,
35683578 /* .heads =*/ NULL ,
35693579 },
3570- /* .dtw_mem_size =*/ 1024 *1024 *128 , // TODO: probably can be removed now
3580+ /* .dtw_mem_size =*/ 1024 *1024 *128 ,
35713581 };
35723582 return result;
35733583}
@@ -7170,7 +7180,7 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
71707180 struct ggml_tensor * trace = ggml_new_tensor_2d (ctx, GGML_TYPE_I32, N + 1 , M + 1 );
71717181
71727182 cost = whisper_set_f32 (cost, INFINITY);
7173- trace = whisper_set_f32 (trace, -1 );
7183+ trace = whisper_set_i32 (trace, -1 );
71747184 whisper_set_f32_nd (cost, 0 , 0 , 0 , 0 , 0.0 );
71757185
71767186 // dtw
@@ -7306,9 +7316,9 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
73067316 // Our ggml buffer should be pre-allocated somewhere during init and reused
73077317 // when we call this function
73087318 struct ggml_init_params gparams = {
7309- /* .mem_size =*/ ggml_tensor_overhead ()* 1024 + ggml_graph_overhead () ,
7319+ /* .mem_size =*/ ctx-> params . dtw_mem_size ,
73107320 /* .mem_buffer =*/ NULL ,
7311- /* .no_alloc =*/ true ,
7321+ /* .no_alloc =*/ false ,
73127322 };
73137323 struct ggml_context * gctx = ggml_init (gparams);
73147324
@@ -7403,7 +7413,6 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
74037413 ggml_build_forward_expand (gf, w);
74047414
74057415 ggml_backend_ptr backend { ggml_backend_init_by_type (GGML_BACKEND_DEVICE_TYPE_CPU, nullptr ) };
7406- ggml_backend_buffer_ptr buf { ggml_backend_alloc_ctx_tensors (gctx, backend.get ()) };
74077416 ggml_backend_graph_compute (backend.get (), gf);
74087417
74097418 ggml_tensor * alignment = dtw_and_backtrace (gctx, w);
0 commit comments