|
31 | 31 | #include <mutex> |
32 | 32 | #include <queue> |
33 | 33 | #include <chrono> |
| 34 | +#include <set> |
| 35 | +#include <optional> |
34 | 36 |
|
35 | 37 | #include "ggml-impl.h" |
36 | 38 | #include "ggml-backend-impl.h" |
@@ -93,15 +95,23 @@ int32_t ggml_cann_get_device() { |
93 | 95 | } |
94 | 96 |
|
95 | 97 | /** |
96 | | - * @brief Convert the value obtained from getenv to a lowercase std::string. |
97 | | - * |
98 | | - * @param env_var C-style string(char*) |
99 | | - * @return A string of type std::string. |
| 98 | + * @brief Get the value of the environment variable `name`. |
| 99 | + * @return The value as a std::string if the variable is set (even if empty); std::nullopt otherwise. |
100 | 100 | */ |
101 | | -static std::string to_lower_case(const char* env_var){ |
102 | | - std::string mem_pool_type(env_var ? env_var : ""); |
103 | | - std::transform(mem_pool_type.begin(), mem_pool_type.end(), mem_pool_type.begin(), ::tolower); |
104 | | - return mem_pool_type; |
| 101 | +std::optional<std::string> get_env(const std::string& name) { |
| 102 | + const char* val = std::getenv(name.c_str()); |
| 103 | + if (!val) return std::nullopt; |
| 104 | + return std::string(val); |
| 105 | +} |
| 106 | + |
| 107 | +/** |
| 108 | + * @brief Interpret a string as a boolean flag: true if it case-insensitively equals "on", "1", "yes", "y", "enable" or "true"; false otherwise. |
| 109 | + */ |
| 110 | +bool parse_bool(const std::string& value) { |
| 111 | + std::string res = value; |
| 112 | + std::transform(res.begin(), res.end(), res.begin(), ::tolower); |
| 113 | + std::set<std::string> valid_values = {"on", "1", "yes", "y", "enable", "true"}; |
| 114 | + return valid_values.find(res) != valid_values.end(); |
105 | 115 | } |
106 | 116 |
|
107 | 117 | /** |
@@ -225,7 +235,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool { |
225 | 235 | * @param device The device ID to associate with this buffer pool. |
226 | 236 | */ |
227 | 237 | explicit ggml_cann_pool_buf_prio(int device) : device(device) { |
228 | | - disable_clean = getenv("GGML_CANN_POOL_DISABLE_CLEAN") != nullptr; |
| 238 | + disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); |
229 | 239 | } |
230 | 240 |
|
231 | 241 | /** |
@@ -421,7 +431,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool { |
421 | 431 | * @param device The device ID to associate with this buffer pool. |
422 | 432 | */ |
423 | 433 | explicit ggml_cann_pool_buf(int device) : device(device) { |
424 | | - disable_clean = getenv("GGML_CANN_POOL_DISABLE_CLEAN") != nullptr; |
| 434 | + disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); |
425 | 435 | } |
426 | 436 |
|
427 | 437 | /** |
@@ -742,20 +752,17 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { |
742 | 752 | */ |
743 | 753 | std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device( |
744 | 754 | int device) { |
745 | | - const char* env_var = getenv("GGML_CANN_MEM_POOL"); |
746 | | - std::string mem_pool_type = to_lower_case(env_var); |
| 755 | + std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or(""); |
| 756 | + std::transform(mem_pool_type.begin(), mem_pool_type.end(), mem_pool_type.begin(), ::tolower); |
747 | 757 |
|
748 | 758 | if (mem_pool_type == "prio") { |
749 | 759 | GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device); |
750 | 760 | return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf_prio(device)); |
751 | 761 | } |
752 | 762 |
|
753 | | - if (mem_pool_type.empty() && ggml_cann_info().devices[device].vmm) { |
| 763 | + if (ggml_cann_info().devices[device].vmm && mem_pool_type != "leg") { |
754 | 764 | GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device); |
755 | 765 | return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device)); |
756 | | - }else{ |
757 | | - GGML_LOG_INFO("%s: device %d use buffer pool\n", __func__, device); |
758 | | - return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf(device)); |
759 | 766 | } |
760 | 767 |
|
761 | 768 | GGML_LOG_INFO("%s: device %d use buffer pool\n", __func__, device); |
|
0 commit comments