@@ -17,6 +17,7 @@
 #include <cmath>
 #include <functional>
 #include <map>
+#include <regex>
 #include <sstream>
 #include <stdexcept>
 
@@ -378,9 +379,12 @@ struct llama_model::impl {
     layer_dev dev_input = {};
     layer_dev dev_output = {};
     std::vector<layer_dev> dev_layer;
+
+    bool has_tensor_overrides;
 };
 
 llama_model::llama_model(const llama_model_params & params) : params(params), pimpl(std::make_unique<impl>()) {
+    pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern;
 }
 
 llama_model::~llama_model() {}
@@ -1571,9 +1575,26 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str());
         }
 
-        ggml_backend_buffer_type_t buft = select_weight_buft(hparams, t_meta, op, *buft_list);
+        ggml_backend_buffer_type_t buft = nullptr;
+
+        // check overrides
+        if (ml.tensor_buft_overrides) {
+            std::string tensor_name = tn.str();
+            for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
+                std::regex pattern(overrides->pattern);
+                if (std::regex_search(tensor_name, pattern)) {
+                    LLAMA_LOG_DEBUG("tensor %s buffer type overridden to %s\n", tensor_name.c_str(), ggml_backend_buft_name(overrides->buft));
+                    buft = overrides->buft;
+                    break;
+                }
+            }
+        }
+
         if (!buft) {
-            throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
+            buft = select_weight_buft(hparams, t_meta, op, *buft_list);
+            if (!buft) {
+                throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
+            }
         }
 
         // avoid using a host buffer when using mmap
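
The override scan added above is a first-match-wins pass over a list whose end is marked by an entry with a null pattern; each entry pairs a std::regex with a target buffer type, and any tensor that matches no pattern falls back to select_weight_buft() as before. Below is a minimal standalone sketch of the same matching logic, with an illustrative OverrideEntry struct and string buffer-type names standing in for ggml_backend_buffer_type_t; the tensor and buffer-type names are examples, not taken from this diff.

// Standalone sketch of the first-match override scan above; "OverrideEntry"
// and the tensor/buffer-type names are illustrative, not llama.cpp API.
#include <cstdio>
#include <regex>
#include <string>

struct OverrideEntry {
    const char * pattern; // regex matched against the tensor name; nullptr ends the list
    const char * buft;    // stand-in for ggml_backend_buffer_type_t
};

// Returns the buffer type of the first matching pattern, or nullptr if none match
// (the real code then falls back to select_weight_buft()).
static const char * match_override(const std::string & tensor_name, const OverrideEntry * overrides) {
    if (!overrides) {
        return nullptr;
    }
    for (const OverrideEntry * o = overrides; o->pattern != nullptr; ++o) {
        if (std::regex_search(tensor_name, std::regex(o->pattern))) {
            return o->buft;
        }
    }
    return nullptr;
}

int main() {
    const OverrideEntry overrides[] = {
        { "\\.ffn_(up|down|gate)_exps\\.", "CPU"   }, // e.g. keep expert FFN weights in host memory
        { "blk\\.[0-9]+\\.attn_",          "CUDA0" }, // attention weights on the first GPU
        { nullptr,                         nullptr }, // sentinel: null pattern ends the scan
    };

    for (const char * name : { "blk.0.attn_q.weight", "blk.1.ffn_up_exps.weight", "output.weight" }) {
        const char * buft = match_override(name, overrides);
        std::printf("%-26s -> %s\n", name, buft ? buft : "(no override)");
    }
    return 0;
}
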
@@ -4151,6 +4172,10 @@ ggml_backend_buffer_type_t llama_model::select_buft(int il) const {
             });
 }
 
+bool llama_model::has_tensor_overrides() const {
+    return pimpl->has_tensor_overrides;
+}
+
 const ggml_tensor * llama_model::get_tensor(const char * name) const {
     auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(),
             [name](const std::pair<std::string, ggml_tensor *> & it) {
@@ -12319,6 +12344,7 @@ llm_graph_result_ptr llama_model::build_graph(
 llama_model_params llama_model_default_params() {
     llama_model_params result = {
         /*.devices               =*/ nullptr,
+        /*.tensor_buft_overrides =*/ nullptr,
         /*.n_gpu_layers          =*/ 0,
         /*.split_mode            =*/ LLAMA_SPLIT_MODE_LAYER,
         /*.main_gpu              =*/ 0,
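
On the caller side, tensor_buft_overrides now defaults to nullptr (no overrides), and the constructor change earlier treats the list as present only when the first entry has a non-null pattern, so a user-supplied list must end with a null-pattern sentinel. The following is a hedged sketch of how an application might populate it, assuming a public override struct with the pattern/buft fields used in load_tensors; the type name llama_model_tensor_buft_override and the llama_model_load_from_file / llama_model_free calls are assumptions about the surrounding API and are not shown in this diff, while ggml_backend_cpu_buffer_type() is the standard ggml host buffer type.

// Hypothetical caller-side setup; the override struct name is assumed,
// only its pattern/buft fields come from this diff.
#include "llama.h"
#include "ggml-backend.h"

int main(int argc, char ** argv) {
    if (argc < 2) {
        return 1;
    }

    // List terminated by a null pattern, matching the
    // params.tensor_buft_overrides[0].pattern check in the constructor.
    llama_model_tensor_buft_override overrides[] = {
        { "ffn_.*_exps", ggml_backend_cpu_buffer_type() }, // keep MoE expert weights in host memory
        { nullptr,       nullptr                        },
    };

    llama_model_params mparams = llama_model_default_params();
    mparams.tensor_buft_overrides = overrides; // nullptr (the default) disables overrides

    llama_model * model = llama_model_load_from_file(argv[1], mparams);
    if (model == nullptr) {
        return 1;
    }
    llama_model_free(model);
    return 0;
}
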