Skip to content

Commit ee19e15

Browse files
committed
whisper : add support for ggml_backend_buffer_type
Signed-off-by: Dan Johansson <[email protected]>
1 parent fc7b1ee commit ee19e15

File tree

3 files changed

+385
-188
lines changed

3 files changed

+385
-188
lines changed

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ endif()
8282

8383
add_library(whisper
8484
../include/whisper.h
85+
whisper-arch.h
8586
whisper.cpp
8687
)
8788

src/whisper-arch.h

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <[email protected]>
2+
// SPDX-License-Identifier: MIT
3+
//
4+
5+
#pragma once
6+
7+
#include "ggml.h"
8+
9+
#include <map>
10+
11+
// Identifiers for the weight/bias tensors of the Whisper ASR model.
// These serve as keys into ASR_TENSOR_NAMES (tensor-name templates) and
// ASR_TENSOR_INFO (the ggml operator each tensor is consumed by), which
// the backend buffer-type selection logic uses.
enum asr_tensor {
    // embeddings
    ASR_TENSOR_ENC_POS_EMBD,          // encoder positional embedding
    ASR_TENSOR_DEC_POS_EMBD,          // decoder positional embedding
    ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, // decoder token embedding table

    // layer norms
    ASR_TENSOR_LN_WEIGHT,
    ASR_TENSOR_LN_BIAS,

    // encoder convolution front-end
    ASR_TENSOR_CONV1_WEIGHT,
    ASR_TENSOR_CONV1_BIAS,
    ASR_TENSOR_CONV2_WEIGHT,
    ASR_TENSOR_CONV2_BIAS,

    // encoder post layer norm
    ASR_TENSOR_LN_POST_WEIGHT,
    ASR_TENSOR_LN_POST_BIAS,

    // MLP block
    ASR_TENSOR_MLP_LN_WEIGHT,
    ASR_TENSOR_MLP_LN_BIAS,
    ASR_TENSOR_MLP_0_WEIGHT,
    ASR_TENSOR_MLP_0_BIAS,
    ASR_TENSOR_MLP_2_WEIGHT,
    ASR_TENSOR_MLP_2_BIAS,

    // attention block (self- and cross-attention share these ids;
    // the asr_system key disambiguates them)
    ASR_TENSOR_ATTN_LN_WEIGHT,
    ASR_TENSOR_ATTN_LN_BIAS,
    ASR_TENSOR_ATTN_QUERY_WEIGHT,
    ASR_TENSOR_ATTN_QUERY_BIAS,
    ASR_TENSOR_ATTN_KEY_WEIGHT,
    ASR_TENSOR_ATTN_VALUE_WEIGHT,
    ASR_TENSOR_ATTN_VALUE_BIAS,
    ASR_TENSOR_ATTN_OUT_WEIGHT,
    ASR_TENSOR_ATTN_OUT_BIAS,
};
39+
40+
// The three tensor groups of the model; used as the outer key of
// ASR_TENSOR_NAMES so that the same asr_tensor id can map to different
// tensor names in the encoder, the decoder, and the cross-attention.
enum asr_system {
    ASR_SYSTEM_ENCODER,
    ASR_SYSTEM_DECODER,
    ASR_SYSTEM_CROSS
};
45+
46+
static const std::map<asr_system, std::map<asr_tensor, const char *>> ASR_TENSOR_NAMES = {
47+
{
48+
ASR_SYSTEM_ENCODER,
49+
{
50+
{ASR_TENSOR_ENC_POS_EMBD, "encoder.positional_embedding"},
51+
{ASR_TENSOR_CONV1_WEIGHT, "encoder.conv1.weight"},
52+
{ASR_TENSOR_CONV1_BIAS, "encoder.conv1.bias"},
53+
{ASR_TENSOR_CONV2_WEIGHT, "encoder.conv2.weight"},
54+
{ASR_TENSOR_CONV2_BIAS, "encoder.conv2.bias"},
55+
{ASR_TENSOR_LN_WEIGHT, "encoder.ln_post.weight"},
56+
{ASR_TENSOR_LN_POST_BIAS, "encoder.ln_post.bias"},
57+
{ASR_TENSOR_MLP_LN_WEIGHT, "encoder.blocks.%d.mlp_ln.weight"},
58+
{ASR_TENSOR_MLP_LN_BIAS, "encoder.blocks.%d.mlp_ln.bias"},
59+
{ASR_TENSOR_MLP_0_WEIGHT, "encoder.blocks.%d.mlp.0.weight"},
60+
{ASR_TENSOR_MLP_0_BIAS, "encoder.blocks.%d.mlp.0.bias"},
61+
{ASR_TENSOR_MLP_2_WEIGHT, "encoder.blocks.%d.mlp.2.weight"},
62+
{ASR_TENSOR_MLP_2_BIAS, "encoder.blocks.%d.mlp.2.bias"},
63+
{ASR_TENSOR_ATTN_LN_WEIGHT, "encoder.blocks.%d.attn_ln.weight"},
64+
{ASR_TENSOR_ATTN_LN_BIAS, "encoder.blocks.%d.attn_ln.bias"},
65+
{ASR_TENSOR_ATTN_QUERY_WEIGHT, "encoder.blocks.%d.attn.query.weight"},
66+
{ASR_TENSOR_ATTN_QUERY_BIAS, "encoder.blocks.%d.attn.query.bias"},
67+
{ASR_TENSOR_ATTN_KEY_WEIGHT, "encoder.blocks.%d.attn.key.weight"},
68+
{ASR_TENSOR_ATTN_VALUE_WEIGHT, "encoder.blocks.%d.attn.value.weight"},
69+
{ASR_TENSOR_ATTN_VALUE_BIAS, "encoder.blocks.%d.attn.value.bias"},
70+
{ASR_TENSOR_ATTN_OUT_WEIGHT, "encoder.blocks.%d.attn.out.weight"},
71+
{ASR_TENSOR_ATTN_OUT_BIAS, "encoder.blocks.%d.attn.out.bias"},
72+
},
73+
},
74+
{
75+
ASR_SYSTEM_DECODER,
76+
{
77+
{ASR_TENSOR_DEC_POS_EMBD, "decoder.positional_embedding"},
78+
{ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, "decoder.token_embedding.weight"},
79+
{ASR_TENSOR_LN_WEIGHT, "decoder.ln.weight"},
80+
{ASR_TENSOR_LN_BIAS, "decoder.ln.bias"},
81+
82+
{ASR_TENSOR_MLP_LN_WEIGHT, "decoder.blocks.%d.mlp_ln.weight"},
83+
{ASR_TENSOR_MLP_LN_BIAS, "decoder.blocks.%d.mlp_ln.bias"},
84+
{ASR_TENSOR_MLP_0_WEIGHT, "decoder.blocks.%d.mlp.0.weight"},
85+
{ASR_TENSOR_MLP_0_BIAS, "decoder.blocks.%d.mlp.0.bias"},
86+
{ASR_TENSOR_MLP_2_WEIGHT, "decoder.blocks.%d.mlp.2.weight"},
87+
{ASR_TENSOR_MLP_2_BIAS, "decoder.blocks.%d.mlp.2.bias"},
88+
{ASR_TENSOR_ATTN_LN_WEIGHT, "decoder.blocks.%d.attn_ln.weight"},
89+
{ASR_TENSOR_ATTN_LN_BIAS, "decoder.blocks.%d.attn_ln.bias"},
90+
{ASR_TENSOR_ATTN_QUERY_WEIGHT, "decoder.blocks.%d.attn.query.weight"},
91+
{ASR_TENSOR_ATTN_QUERY_BIAS, "decoder.blocks.%d.attn.query.bias"},
92+
{ASR_TENSOR_ATTN_KEY_WEIGHT, "decoder.blocks.%d.attn.key.weight"},
93+
{ASR_TENSOR_ATTN_VALUE_WEIGHT, "decoder.blocks.%d.attn.value.weight"},
94+
{ASR_TENSOR_ATTN_VALUE_BIAS, "decoder.blocks.%d.attn.value.bias"},
95+
{ASR_TENSOR_ATTN_OUT_WEIGHT, "decoder.blocks.%d.attn.out.weight"},
96+
{ASR_TENSOR_ATTN_OUT_BIAS, "decoder.blocks.%d.attn.out.bias"},
97+
},
98+
},
99+
{
100+
ASR_SYSTEM_CROSS,
101+
{
102+
{ASR_TENSOR_ATTN_LN_WEIGHT, "decoder.blocks.%d.cross_attn_ln.weight"},
103+
{ASR_TENSOR_ATTN_LN_BIAS, "decoder.blocks.%d.cross_attn_ln.bias"},
104+
{ASR_TENSOR_ATTN_QUERY_WEIGHT, "decoder.blocks.%d.cross_attn.query.weight"},
105+
{ASR_TENSOR_ATTN_QUERY_BIAS, "decoder.blocks.%d.cross_attn.query.bias"},
106+
{ASR_TENSOR_ATTN_KEY_WEIGHT, "decoder.blocks.%d.cross_attn.key.weight"},
107+
{ASR_TENSOR_ATTN_VALUE_WEIGHT, "decoder.blocks.%d.cross_attn.value.weight"},
108+
{ASR_TENSOR_ATTN_VALUE_BIAS, "decoder.blocks.%d.cross_attn.value.bias"},
109+
{ASR_TENSOR_ATTN_OUT_WEIGHT, "decoder.blocks.%d.cross_attn.out.weight"},
110+
{ASR_TENSOR_ATTN_OUT_BIAS, "decoder.blocks.%d.cross_attn.out.bias"},
111+
},
112+
},
113+
};
114+
115+
static const std::map<asr_tensor, ggml_op> ASR_TENSOR_INFO = {
116+
{ASR_TENSOR_ENC_POS_EMBD, GGML_OP_ADD},
117+
{ASR_TENSOR_DEC_POS_EMBD, GGML_OP_GET_ROWS},
118+
// Note: ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT is also used by GGML_OP_MAT_MUL. Need to figure out a way how to handle
119+
// weight tensors that are used by multiple different operators when extra_buffer_type implementations accelerate
120+
// more than just GGML_OP_MUL_MAT.
121+
{ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, GGML_OP_GET_ROWS},
122+
{ASR_TENSOR_LN_WEIGHT, GGML_OP_MUL},
123+
{ASR_TENSOR_LN_BIAS, GGML_OP_ADD},
124+
{ASR_TENSOR_CONV1_WEIGHT, GGML_OP_IM2COL},
125+
{ASR_TENSOR_CONV1_BIAS, GGML_OP_ADD},
126+
{ASR_TENSOR_CONV2_WEIGHT, GGML_OP_IM2COL},
127+
{ASR_TENSOR_CONV2_BIAS, GGML_OP_ADD},
128+
{ASR_TENSOR_LN_POST_WEIGHT, GGML_OP_MUL},
129+
{ASR_TENSOR_LN_POST_BIAS, GGML_OP_ADD},
130+
{ASR_TENSOR_MLP_LN_WEIGHT, GGML_OP_MUL},
131+
{ASR_TENSOR_MLP_LN_BIAS, GGML_OP_ADD},
132+
{ASR_TENSOR_MLP_0_WEIGHT, GGML_OP_MUL_MAT},
133+
{ASR_TENSOR_MLP_0_BIAS, GGML_OP_ADD},
134+
{ASR_TENSOR_MLP_2_WEIGHT, GGML_OP_MUL_MAT},
135+
{ASR_TENSOR_MLP_2_BIAS, GGML_OP_ADD},
136+
{ASR_TENSOR_ATTN_LN_WEIGHT, GGML_OP_MUL},
137+
{ASR_TENSOR_ATTN_LN_BIAS, GGML_OP_ADD},
138+
{ASR_TENSOR_ATTN_QUERY_WEIGHT, GGML_OP_MUL_MAT},
139+
{ASR_TENSOR_ATTN_QUERY_BIAS, GGML_OP_ADD},
140+
{ASR_TENSOR_ATTN_KEY_WEIGHT, GGML_OP_MUL_MAT},
141+
{ASR_TENSOR_ATTN_VALUE_WEIGHT, GGML_OP_MUL_MAT},
142+
{ASR_TENSOR_ATTN_VALUE_BIAS, GGML_OP_ADD},
143+
{ASR_TENSOR_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT},
144+
{ASR_TENSOR_ATTN_OUT_BIAS, GGML_OP_ADD},
145+
};

0 commit comments

Comments (0)