Skip to content

Commit f279daf

Browse files
committed
llama : expose API to retrieve devices used by model.
It's useful from the library to be able to do things like list the features being used by the loaded model.
1 parent 745aa53 commit f279daf

File tree

3 files changed

+20
-0
lines changed

3 files changed

+20
-0
lines changed

include/llama.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,8 @@ extern "C" {
499499
LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx);
500500
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
501501

502+
LLAMA_API size_t llama_n_backends(const struct llama_context * ctx);
503+
LLAMA_API size_t llama_get_backends(const struct llama_context * ctx, ggml_backend_t * out_buf, size_t out_len);
502504
DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
503505

504506
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
@@ -510,6 +512,7 @@ extern "C" {
510512
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
511513
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
512514
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
515+
LLAMA_API const ggml_backend_dev_t * llama_model_get_devices (const struct llama_model * model, size_t * out_len);
513516

514517
// Get the model's RoPE frequency scaling factor
515518
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);

src/llama-context.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2428,6 +2428,18 @@ llama_memory_t llama_get_memory(const struct llama_context * ctx) {
24282428
return ctx->get_memory();
24292429
}
24302430

2431+
size_t llama_n_backends(const struct llama_context * ctx) {
2432+
return ctx->backends.size();
2433+
}
2434+
2435+
size_t llama_get_backends(const struct llama_context * ctx, ggml_backend_t * out, size_t out_len) {
2436+
size_t return_len = std::min(ctx->backends.size(), out_len);
2437+
for (size_t i = 0; i < return_len; i++) {
2438+
out[i] = ctx->backends[i].get();
2439+
}
2440+
return return_len;
2441+
}
2442+
24312443
void llama_memory_clear(llama_memory_t mem, bool data) {
24322444
if (!mem) {
24332445
return;

src/llama-model.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13628,6 +13628,11 @@ const char * llama_model_cls_label(const struct llama_model * model, uint32_t i)
1362813628
return nullptr;
1362913629
}
1363013630

13631+
const ggml_backend_dev_t * llama_model_get_devices(const struct llama_model * model, size_t * out_len) {
13632+
*out_len = model->devices.size();
13633+
return model->devices.data();
13634+
}
13635+
1363113636
// deprecated
1363213637
int32_t llama_n_ctx_train(const llama_model * model) {
1363313638
return llama_model_n_ctx_train(model);

0 commit comments

Comments
 (0)