11#include " whisper.h"
22
3- #ifdef WHISPER_USE_COREML
4- #include " coreml/whisper-encoder.h"
5- #endif
6-
73#include " ggml-cpu.h"
84
9- #ifdef GGML_USE_METAL
10- #include " ggml-metal.h"
11- #endif
12-
13- #ifdef GGML_USE_CUDA
14- #include " ggml-cuda.h"
15- #endif
16-
17- #ifdef GGML_USE_SYCL
18- #include " ggml-sycl.h"
19- #endif
20-
21- #ifdef GGML_USE_VULKAN
22- #include " ggml-vulkan.h"
23- #endif
5+ #include " ggml.h"
6+ #include " ggml-alloc.h"
7+ #include " ggml-backend.h"
248
25- #ifdef GGML_USE_BLAS
26- #include " ggml-blas .h"
9+ #ifdef WHISPER_USE_COREML
10+ #include " coreml/whisper-encoder .h"
2711#endif
2812
2913#ifdef WHISPER_USE_OPENVINO
3014#include " openvino/whisper-openvino-encoder.h"
3115#endif
3216
33- #ifdef GGML_USE_CANN
34- #include " ggml-cann.h"
35- #endif
36-
37- #include " ggml.h"
38- #include " ggml-alloc.h"
39- #include " ggml-backend.h"
40-
4117#include < atomic>
4218#include < algorithm>
4319#include < cassert>
@@ -195,14 +171,13 @@ static bool ggml_graph_compute_helper(
195171
196172 for (int i = 0 ; i < ggml_backend_sched_get_n_backends (sched); ++i) {
197173 ggml_backend_t backend = ggml_backend_sched_get_backend (sched, i);
198- if ( ggml_backend_is_cpu ( backend)) {
199- ggml_backend_cpu_set_n_threads (backend, n_threads) ;
200- }
201- # ifdef GGML_USE_BLAS
202- if (ggml_backend_is_blas (backend) ) {
203- ggml_backend_blas_set_n_threads (backend, n_threads);
174+ ggml_backend_dev_t dev = ggml_backend_get_device ( backend);
175+ ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg (dev) : nullptr ;
176+
177+ auto * fn_set_n_threads = ( ggml_backend_set_n_threads_t ) ggml_backend_reg_get_proc_address (reg, " ggml_backend_set_n_threads " );
178+ if (fn_set_n_threads ) {
179+ fn_set_n_threads (backend, n_threads);
204180 }
205- #endif
206181 }
207182
208183 bool t = ggml_backend_sched_graph_compute (sched, graph) == GGML_STATUS_SUCCESS;
@@ -1256,67 +1231,23 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
12561231}
12571232
12581233static ggml_backend_t whisper_backend_init_gpu (const whisper_context_params & params) {
1259- ggml_backend_t result = NULL ;
1260-
12611234 ggml_log_set (g_state.log_callback , g_state.log_callback_user_data );
12621235
1263- #ifdef GGML_USE_CUDA
12641236 if (params.use_gpu ) {
1265- WHISPER_LOG_INFO (" %s: using CUDA backend\n " , __func__);
1266- result = ggml_backend_cuda_init (params.gpu_device );
1267- if (!result) {
1268- WHISPER_LOG_ERROR (" %s: ggml_backend_cuda_init() failed\n " , __func__);
1269- }
1270- }
1271- #endif
1272-
1273- #ifdef GGML_USE_METAL
1274- if (params.use_gpu ) {
1275- WHISPER_LOG_INFO (" %s: using Metal backend\n " , __func__);
1276- result = ggml_backend_metal_init ();
1277- if (!result) {
1278- WHISPER_LOG_ERROR (" %s: ggml_backend_metal_init() failed\n " , __func__);
1279- } else if (!ggml_backend_metal_supports_family (result, 7 )) {
1280- WHISPER_LOG_ERROR (" %s: Metal GPU does not support family 7 - falling back to CPU\n " , __func__);
1281- ggml_backend_free (result);
1282- result = NULL ;
1283- }
1284- }
1285- #endif
1286-
1287- #ifdef GGML_USE_SYCL
1288- if (params.use_gpu ) {
1289- WHISPER_LOG_INFO (" %s: using SYCL backend\n " , __func__);
1290- result = ggml_backend_sycl_init (params.gpu_device );
1291- if (!result) {
1292- WHISPER_LOG_ERROR (" %s: ggml_backend_sycl_init() failed\n " , __func__);
1293- }
1294- }
1295- #endif
1296-
1297- #ifdef GGML_USE_VULKAN
1298- if (params.use_gpu ) {
1299- WHISPER_LOG_INFO (" %s: using Vulkan backend\n " , __func__);
1300- result = ggml_backend_vk_init (params.gpu_device );
1301- if (!result) {
1302- WHISPER_LOG_ERROR (" %s: ggml_backend_vk_init() failed\n " , __func__);
1303- }
1304- }
1305- #endif
1306-
1307- #ifdef GGML_USE_CANN
1308- if (params.use_gpu ) {
1309- WHISPER_LOG_INFO (" %s: using CANN backend\n " , __func__);
1310- result = ggml_backend_cann_init (params.gpu_device );
1311- if (!result) {
1312- WHISPER_LOG_ERROR (" %s: ggml_backend_cann_init() failed\n " , __func__);
1237+ for (size_t i = 0 ; i < ggml_backend_dev_count (); ++i) {
1238+ ggml_backend_dev_t dev = ggml_backend_dev_get (i);
1239+ if (ggml_backend_dev_type (dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
1240+ WHISPER_LOG_INFO (" %s: using %s backend\n " , __func__, ggml_backend_dev_name (dev));
1241+ ggml_backend_t result = ggml_backend_dev_init (dev, nullptr );
1242+ if (!result) {
1243+ WHISPER_LOG_ERROR (" %s: failed to initialize %s backend\n " , __func__, ggml_backend_dev_name (dev));
1244+ }
1245+ return result;
1246+ }
13131247 }
13141248 }
1315- #endif
13161249
1317- GGML_UNUSED (params);
1318-
1319- return result;
1250+ return nullptr ;
13201251}
13211252
13221253static std::vector<ggml_backend_t > whisper_backend_init (const whisper_context_params & params) {
@@ -1328,17 +1259,19 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
13281259 result.push_back (backend_gpu);
13291260 }
13301261
1331- #ifdef GGML_USE_BLAS
1332- {
1333- WHISPER_LOG_INFO (" %s: using BLAS backend\n " , __func__);
1334- ggml_backend_t backend_blas = ggml_backend_blas_init ();
1335- if (!backend_blas) {
1336- WHISPER_LOG_ERROR (" %s: ggml_backend_blas_init() failed\n " , __func__);
1337- } else {
1338- result.push_back (backend_blas);
1262+ // ACCEL backends
1263+ for (size_t i = 0 ; i < ggml_backend_dev_count (); ++i) {
1264+ ggml_backend_dev_t dev = ggml_backend_dev_get (i);
1265+ if (ggml_backend_dev_type (dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
1266+ WHISPER_LOG_INFO (" %s: using %s backend\n " , __func__, ggml_backend_dev_name (dev));
1267+ ggml_backend_t backend = ggml_backend_dev_init (dev, nullptr );
1268+ if (!backend) {
1269+ WHISPER_LOG_ERROR (" %s: failed to initialize %s backend\n " , __func__, ggml_backend_dev_name (dev));
1270+ continue ;
1271+ }
1272+ result.push_back (backend);
13391273 }
13401274 }
1341- #endif
13421275
13431276 GGML_UNUSED (params);
13441277
@@ -1348,33 +1281,20 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
13481281}
13491282
13501283static ggml_backend_buffer_type_t whisper_default_buffer_type (const whisper_context_params & params) {
1351- ggml_backend_buffer_type_t result = nullptr ;
1352-
1353- params.use_gpu || (result = ggml_backend_cpu_buffer_type ());
1354-
1355- #ifdef GGML_USE_CUDA
1356- result || (result = ggml_backend_cuda_buffer_type (params.gpu_device ));
1357- #endif
1358-
1359- #ifdef GGML_USE_METAL
1360- result || (result = ggml_backend_metal_buffer_type ());
1361- #endif
1362-
1363- #ifdef GGML_USE_SYCL
1364- result || (result = ggml_backend_sycl_buffer_type (params.gpu_device ));
1365- #endif
1366-
1367- #ifdef GGML_USE_VULKAN
1368- result || (result = ggml_backend_vk_buffer_type (params.gpu_device ));
1369- #endif
1370-
1371- #ifdef GGML_USE_CANN
1372- result || (result == ggml_backend_cann_buffer_type (params.gpu_device ));
1373- #endif
1284+ if (!params.use_gpu ) {
1285+ return ggml_backend_cpu_buffer_type ();
1286+ }
13741287
1375- result || (result = ggml_backend_cpu_buffer_type ());
1288+ // if we have a GPU device - use it
1289+ for (size_t i = 0 ; i < ggml_backend_dev_count (); ++i) {
1290+ ggml_backend_dev_t dev = ggml_backend_dev_get (i);
1291+ if (ggml_backend_dev_type (dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
1292+ WHISPER_LOG_INFO (" %s: using device %s (%s)\n " , __func__, ggml_backend_dev_name (dev), ggml_backend_dev_description (dev));
1293+ return ggml_backend_dev_buffer_type (dev);
1294+ }
1295+ }
13761296
1377- return result ;
1297+ return ggml_backend_cpu_buffer_type () ;
13781298}
13791299
13801300// load the model from a ggml file
@@ -3668,8 +3588,7 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
36683588 WHISPER_LOG_INFO (" %s: flash attn = %d\n " , __func__, params.flash_attn );
36693589 WHISPER_LOG_INFO (" %s: gpu_device = %d\n " , __func__, params.gpu_device );
36703590 WHISPER_LOG_INFO (" %s: dtw = %d\n " , __func__, params.dtw_token_timestamps );
3671-
3672- // TODO: temporary call to force backend registry initialization
3591+ WHISPER_LOG_INFO (" %s: devices = %zu\n " , __func__, ggml_backend_dev_count ());
36733592 WHISPER_LOG_INFO (" %s: backends = %zu\n " , __func__, ggml_backend_reg_count ());
36743593
36753594 whisper_context * ctx = new whisper_context;
@@ -7427,6 +7346,11 @@ static void whisper_log_internal(ggml_log_level level, const char * format, ...)
74277346static void whisper_log_callback_default (ggml_log_level level, const char * text, void * user_data) {
74287347 (void ) level;
74297348 (void ) user_data;
7349+ #ifndef WHISPER_DEBUG
7350+ if (level == GGML_LOG_LEVEL_DEBUG) {
7351+ return ;
7352+ }
7353+ #endif
74307354 fputs (text, stderr);
74317355 fflush (stderr);
74327356}
0 commit comments