Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 35 additions & 13 deletions src/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1235,21 +1235,36 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
// Initialize the GPU backend selected by `params`.
//
// Device selection walks every registered ggml backend device and counts only
// those of type GGML_BACKEND_DEVICE_TYPE_GPU. The GPU whose position in that
// count equals `params.gpu_device` is chosen; the first GPU encountered is kept
// as a fallback when the requested index is never reached.
//
// Returns the initialized backend, or nullptr when:
//   - `params.use_gpu` is false,
//   - no GPU-type device is registered, or
//   - ggml_backend_dev_init() fails for the chosen device (an error is logged).
static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
    // hand the currently configured log callback to ggml before touching any device
    ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);

    ggml_backend_dev_t dev = nullptr;

    // number of GPU-type devices seen so far
    int cnt = 0;
    if (params.use_gpu) {
        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
            ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i);
            if (ggml_backend_dev_type(dev_cur) == GGML_BACKEND_DEVICE_TYPE_GPU) {
                // cnt == 0 records the first GPU as the fallback;
                // cnt == params.gpu_device overrides it with the requested one
                if (cnt == 0 || cnt == params.gpu_device) {
                    dev = dev_cur;
                }

                // stop as soon as the requested index has been passed
                if (++cnt > params.gpu_device) {
                    break;
                }
            }
        }
    }

    // note: this path is also taken when params.use_gpu is false
    if (dev == nullptr) {
        WHISPER_LOG_INFO("%s: no GPU found\n", __func__);
        return nullptr;
    }

    WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
    ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
    if (!result) {
        WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
    }

    // may be nullptr if initialization failed above
    return result;
}

static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
Expand Down Expand Up @@ -1283,20 +1298,27 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
}

// Pick the default buffer type for the given context parameters.
//
// Starts from the CPU buffer type and, when `params.use_gpu` is set, replaces
// it with the buffer type of the selected GPU device. GPU-type devices are
// counted in enumeration order: the one at position `params.gpu_device` is
// used, with the first GPU encountered kept as a fallback when the requested
// index is never reached.
//
// Returns the CPU buffer type when GPU use is disabled or no GPU-type device
// is registered.
static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
    ggml_backend_buffer_type_t result = ggml_backend_cpu_buffer_type();

    if (!params.use_gpu) {
        return result;
    }

    // number of GPU-type devices seen so far
    int cnt = 0;
    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
            // cnt == 0 records the first GPU as the fallback;
            // cnt == params.gpu_device overrides it with the requested one
            if (cnt == 0 || cnt == params.gpu_device) {
                result = ggml_backend_dev_buffer_type(dev);
            }

            // stop as soon as the requested index has been passed
            if (++cnt > params.gpu_device) {
                break;
            }
        }
    }

    return result;
}

// load the model from a ggml file
Expand Down
Loading