Skip to content

Commit 21d890d

Browse files
authored
whisper : add support for backends with multiple ggml_backend_buffer_type (#2863)
* whisper : add support for ggml_backend_buffer_type
* fix compile error when building on Ubuntu
* remove copyright header from include file

Signed-off-by: Dan Johansson <[email protected]>
1 parent 0b43a02 commit 21d890d

File tree

3 files changed

+382
-188
lines changed

3 files changed

+382
-188
lines changed

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ endif()
102102

103103
add_library(whisper
104104
../include/whisper.h
105+
whisper-arch.h
105106
whisper.cpp
106107
)
107108

src/whisper-arch.h

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
#pragma once
2+
3+
#include "ggml.h"
4+
5+
#include <map>
6+
7+
enum asr_tensor {
8+
ASR_TENSOR_ENC_POS_EMBD,
9+
ASR_TENSOR_DEC_POS_EMBD,
10+
ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT,
11+
ASR_TENSOR_LN_WEIGHT,
12+
ASR_TENSOR_LN_BIAS,
13+
ASR_TENSOR_CONV1_WEIGHT,
14+
ASR_TENSOR_CONV1_BIAS,
15+
ASR_TENSOR_CONV2_WEIGHT,
16+
ASR_TENSOR_CONV2_BIAS,
17+
ASR_TENSOR_LN_POST_WEIGHT,
18+
ASR_TENSOR_LN_POST_BIAS,
19+
ASR_TENSOR_MLP_LN_WEIGHT,
20+
ASR_TENSOR_MLP_LN_BIAS,
21+
ASR_TENSOR_MLP_0_WEIGHT,
22+
ASR_TENSOR_MLP_0_BIAS,
23+
ASR_TENSOR_MLP_2_WEIGHT,
24+
ASR_TENSOR_MLP_2_BIAS,
25+
ASR_TENSOR_ATTN_LN_WEIGHT,
26+
ASR_TENSOR_ATTN_LN_BIAS,
27+
ASR_TENSOR_ATTN_QUERY_WEIGHT,
28+
ASR_TENSOR_ATTN_QUERY_BIAS,
29+
ASR_TENSOR_ATTN_KEY_WEIGHT,
30+
ASR_TENSOR_ATTN_VALUE_WEIGHT,
31+
ASR_TENSOR_ATTN_VALUE_BIAS,
32+
ASR_TENSOR_ATTN_OUT_WEIGHT,
33+
ASR_TENSOR_ATTN_OUT_BIAS,
34+
};
35+
36+
enum asr_system {
37+
ASR_SYSTEM_ENCODER,
38+
ASR_SYSTEM_DECODER,
39+
ASR_SYSTEM_CROSS
40+
};
41+
42+
static const std::map<asr_system, std::map<asr_tensor, const char *>> ASR_TENSOR_NAMES = {
43+
{
44+
ASR_SYSTEM_ENCODER,
45+
{
46+
{ASR_TENSOR_ENC_POS_EMBD, "encoder.positional_embedding"},
47+
{ASR_TENSOR_CONV1_WEIGHT, "encoder.conv1.weight"},
48+
{ASR_TENSOR_CONV1_BIAS, "encoder.conv1.bias"},
49+
{ASR_TENSOR_CONV2_WEIGHT, "encoder.conv2.weight"},
50+
{ASR_TENSOR_CONV2_BIAS, "encoder.conv2.bias"},
51+
{ASR_TENSOR_LN_WEIGHT, "encoder.ln_post.weight"},
52+
{ASR_TENSOR_LN_POST_BIAS, "encoder.ln_post.bias"},
53+
{ASR_TENSOR_MLP_LN_WEIGHT, "encoder.blocks.%d.mlp_ln.weight"},
54+
{ASR_TENSOR_MLP_LN_BIAS, "encoder.blocks.%d.mlp_ln.bias"},
55+
{ASR_TENSOR_MLP_0_WEIGHT, "encoder.blocks.%d.mlp.0.weight"},
56+
{ASR_TENSOR_MLP_0_BIAS, "encoder.blocks.%d.mlp.0.bias"},
57+
{ASR_TENSOR_MLP_2_WEIGHT, "encoder.blocks.%d.mlp.2.weight"},
58+
{ASR_TENSOR_MLP_2_BIAS, "encoder.blocks.%d.mlp.2.bias"},
59+
{ASR_TENSOR_ATTN_LN_WEIGHT, "encoder.blocks.%d.attn_ln.weight"},
60+
{ASR_TENSOR_ATTN_LN_BIAS, "encoder.blocks.%d.attn_ln.bias"},
61+
{ASR_TENSOR_ATTN_QUERY_WEIGHT, "encoder.blocks.%d.attn.query.weight"},
62+
{ASR_TENSOR_ATTN_QUERY_BIAS, "encoder.blocks.%d.attn.query.bias"},
63+
{ASR_TENSOR_ATTN_KEY_WEIGHT, "encoder.blocks.%d.attn.key.weight"},
64+
{ASR_TENSOR_ATTN_VALUE_WEIGHT, "encoder.blocks.%d.attn.value.weight"},
65+
{ASR_TENSOR_ATTN_VALUE_BIAS, "encoder.blocks.%d.attn.value.bias"},
66+
{ASR_TENSOR_ATTN_OUT_WEIGHT, "encoder.blocks.%d.attn.out.weight"},
67+
{ASR_TENSOR_ATTN_OUT_BIAS, "encoder.blocks.%d.attn.out.bias"},
68+
},
69+
},
70+
{
71+
ASR_SYSTEM_DECODER,
72+
{
73+
{ASR_TENSOR_DEC_POS_EMBD, "decoder.positional_embedding"},
74+
{ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, "decoder.token_embedding.weight"},
75+
{ASR_TENSOR_LN_WEIGHT, "decoder.ln.weight"},
76+
{ASR_TENSOR_LN_BIAS, "decoder.ln.bias"},
77+
78+
{ASR_TENSOR_MLP_LN_WEIGHT, "decoder.blocks.%d.mlp_ln.weight"},
79+
{ASR_TENSOR_MLP_LN_BIAS, "decoder.blocks.%d.mlp_ln.bias"},
80+
{ASR_TENSOR_MLP_0_WEIGHT, "decoder.blocks.%d.mlp.0.weight"},
81+
{ASR_TENSOR_MLP_0_BIAS, "decoder.blocks.%d.mlp.0.bias"},
82+
{ASR_TENSOR_MLP_2_WEIGHT, "decoder.blocks.%d.mlp.2.weight"},
83+
{ASR_TENSOR_MLP_2_BIAS, "decoder.blocks.%d.mlp.2.bias"},
84+
{ASR_TENSOR_ATTN_LN_WEIGHT, "decoder.blocks.%d.attn_ln.weight"},
85+
{ASR_TENSOR_ATTN_LN_BIAS, "decoder.blocks.%d.attn_ln.bias"},
86+
{ASR_TENSOR_ATTN_QUERY_WEIGHT, "decoder.blocks.%d.attn.query.weight"},
87+
{ASR_TENSOR_ATTN_QUERY_BIAS, "decoder.blocks.%d.attn.query.bias"},
88+
{ASR_TENSOR_ATTN_KEY_WEIGHT, "decoder.blocks.%d.attn.key.weight"},
89+
{ASR_TENSOR_ATTN_VALUE_WEIGHT, "decoder.blocks.%d.attn.value.weight"},
90+
{ASR_TENSOR_ATTN_VALUE_BIAS, "decoder.blocks.%d.attn.value.bias"},
91+
{ASR_TENSOR_ATTN_OUT_WEIGHT, "decoder.blocks.%d.attn.out.weight"},
92+
{ASR_TENSOR_ATTN_OUT_BIAS, "decoder.blocks.%d.attn.out.bias"},
93+
},
94+
},
95+
{
96+
ASR_SYSTEM_CROSS,
97+
{
98+
{ASR_TENSOR_ATTN_LN_WEIGHT, "decoder.blocks.%d.cross_attn_ln.weight"},
99+
{ASR_TENSOR_ATTN_LN_BIAS, "decoder.blocks.%d.cross_attn_ln.bias"},
100+
{ASR_TENSOR_ATTN_QUERY_WEIGHT, "decoder.blocks.%d.cross_attn.query.weight"},
101+
{ASR_TENSOR_ATTN_QUERY_BIAS, "decoder.blocks.%d.cross_attn.query.bias"},
102+
{ASR_TENSOR_ATTN_KEY_WEIGHT, "decoder.blocks.%d.cross_attn.key.weight"},
103+
{ASR_TENSOR_ATTN_VALUE_WEIGHT, "decoder.blocks.%d.cross_attn.value.weight"},
104+
{ASR_TENSOR_ATTN_VALUE_BIAS, "decoder.blocks.%d.cross_attn.value.bias"},
105+
{ASR_TENSOR_ATTN_OUT_WEIGHT, "decoder.blocks.%d.cross_attn.out.weight"},
106+
{ASR_TENSOR_ATTN_OUT_BIAS, "decoder.blocks.%d.cross_attn.out.bias"},
107+
},
108+
},
109+
};
110+
111+
static const std::map<asr_tensor, ggml_op> ASR_TENSOR_INFO = {
112+
{ASR_TENSOR_ENC_POS_EMBD, GGML_OP_ADD},
113+
{ASR_TENSOR_DEC_POS_EMBD, GGML_OP_GET_ROWS},
114+
// Note: ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT is also used by GGML_OP_MAT_MUL. Need to figure out a way how to handle
115+
// weight tensors that are used by multiple different operators when extra_buffer_type implementations accelerate
116+
// more than just GGML_OP_MUL_MAT.
117+
{ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, GGML_OP_GET_ROWS},
118+
{ASR_TENSOR_LN_WEIGHT, GGML_OP_MUL},
119+
{ASR_TENSOR_LN_BIAS, GGML_OP_ADD},
120+
{ASR_TENSOR_CONV1_WEIGHT, GGML_OP_IM2COL},
121+
{ASR_TENSOR_CONV1_BIAS, GGML_OP_ADD},
122+
{ASR_TENSOR_CONV2_WEIGHT, GGML_OP_IM2COL},
123+
{ASR_TENSOR_CONV2_BIAS, GGML_OP_ADD},
124+
{ASR_TENSOR_LN_POST_WEIGHT, GGML_OP_MUL},
125+
{ASR_TENSOR_LN_POST_BIAS, GGML_OP_ADD},
126+
{ASR_TENSOR_MLP_LN_WEIGHT, GGML_OP_MUL},
127+
{ASR_TENSOR_MLP_LN_BIAS, GGML_OP_ADD},
128+
{ASR_TENSOR_MLP_0_WEIGHT, GGML_OP_MUL_MAT},
129+
{ASR_TENSOR_MLP_0_BIAS, GGML_OP_ADD},
130+
{ASR_TENSOR_MLP_2_WEIGHT, GGML_OP_MUL_MAT},
131+
{ASR_TENSOR_MLP_2_BIAS, GGML_OP_ADD},
132+
{ASR_TENSOR_ATTN_LN_WEIGHT, GGML_OP_MUL},
133+
{ASR_TENSOR_ATTN_LN_BIAS, GGML_OP_ADD},
134+
{ASR_TENSOR_ATTN_QUERY_WEIGHT, GGML_OP_MUL_MAT},
135+
{ASR_TENSOR_ATTN_QUERY_BIAS, GGML_OP_ADD},
136+
{ASR_TENSOR_ATTN_KEY_WEIGHT, GGML_OP_MUL_MAT},
137+
{ASR_TENSOR_ATTN_VALUE_WEIGHT, GGML_OP_MUL_MAT},
138+
{ASR_TENSOR_ATTN_VALUE_BIAS, GGML_OP_ADD},
139+
{ASR_TENSOR_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT},
140+
{ASR_TENSOR_ATTN_OUT_BIAS, GGML_OP_ADD},
141+
};

0 commit comments

Comments
 (0)