@@ -1235,21 +1235,36 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
12351235static ggml_backend_t whisper_backend_init_gpu (const whisper_context_params & params) {
12361236 ggml_log_set (g_state.log_callback , g_state.log_callback_user_data );
12371237
1238+ ggml_backend_dev_t dev = nullptr ;
1239+
1240+ int cnt = 0 ;
12381241 if (params.use_gpu ) {
12391242 for (size_t i = 0 ; i < ggml_backend_dev_count (); ++i) {
1240- ggml_backend_dev_t dev = ggml_backend_dev_get (i);
1241- if (ggml_backend_dev_type (dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
1242- WHISPER_LOG_INFO (" %s: using %s backend\n " , __func__, ggml_backend_dev_name (dev));
1243- ggml_backend_t result = ggml_backend_dev_init (dev, nullptr );
1244- if (!result) {
1245- WHISPER_LOG_ERROR (" %s: failed to initialize %s backend\n " , __func__, ggml_backend_dev_name (dev));
1243+ ggml_backend_dev_t dev_cur = ggml_backend_dev_get (i);
1244+ if (ggml_backend_dev_type (dev_cur) == GGML_BACKEND_DEVICE_TYPE_GPU) {
1245+ if (cnt == 0 || cnt == params.gpu_device ) {
1246+ dev = dev_cur;
1247+ }
1248+
1249+ if (++cnt > params.gpu_device ) {
1250+ break ;
12461251 }
1247- return result;
12481252 }
12491253 }
12501254 }
12511255
1252- return nullptr ;
1256+ if (dev == nullptr ) {
1257+ WHISPER_LOG_INFO (" %s: no GPU found\n " , __func__);
1258+ return nullptr ;
1259+ }
1260+
1261+ WHISPER_LOG_INFO (" %s: using %s backend\n " , __func__, ggml_backend_dev_name (dev));
1262+ ggml_backend_t result = ggml_backend_dev_init (dev, nullptr );
1263+ if (!result) {
1264+ WHISPER_LOG_ERROR (" %s: failed to initialize %s backend\n " , __func__, ggml_backend_dev_name (dev));
1265+ }
1266+
1267+ return result;
12531268}
12541269
12551270static std::vector<ggml_backend_t > whisper_backend_init (const whisper_context_params & params) {
@@ -1283,20 +1298,27 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
12831298}
12841299
12851300static ggml_backend_buffer_type_t whisper_default_buffer_type (const whisper_context_params & params) {
1301+ ggml_backend_buffer_type_t result = ggml_backend_cpu_buffer_type ();
1302+
12861303 if (!params.use_gpu ) {
1287- return ggml_backend_cpu_buffer_type () ;
1304+ return result ;
12881305 }
12891306
1290- // if we have a GPU device - use it
1307+ int cnt = 0 ;
12911308 for (size_t i = 0 ; i < ggml_backend_dev_count (); ++i) {
12921309 ggml_backend_dev_t dev = ggml_backend_dev_get (i);
12931310 if (ggml_backend_dev_type (dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
1294- WHISPER_LOG_INFO (" %s: using device %s (%s)\n " , __func__, ggml_backend_dev_name (dev), ggml_backend_dev_description (dev));
1295- return ggml_backend_dev_buffer_type (dev);
1311+ if (cnt == 0 || cnt == params.gpu_device ) {
1312+ result = ggml_backend_dev_buffer_type (dev);
1313+ }
1314+
1315+ if (++cnt > params.gpu_device ) {
1316+ break ;
1317+ }
12961318 }
12971319 }
12981320
1299- return ggml_backend_cpu_buffer_type () ;
1321+ return result ;
13001322}
13011323
13021324// load the model from a ggml file
0 commit comments