|  | 
| 31 | 31 | #include <mutex> | 
| 32 | 32 | #include <queue> | 
| 33 | 33 | #include <chrono> | 
|  | 34 | +#include <unordered_set> | 
|  | 35 | +#include <optional> | 
| 34 | 36 | 
 | 
| 35 | 37 | #include "ggml-impl.h" | 
| 36 | 38 | #include "ggml-backend-impl.h" | 
| @@ -93,6 +95,26 @@ int32_t ggml_cann_get_device() { | 
| 93 | 95 |     return id; | 
| 94 | 96 | } | 
| 95 | 97 | 
 | 
|  | 98 | +/** | 
|  | 99 | + * @brief Get the value of the specified environment variable (name). | 
|  | 100 | + *        if not empty, return a std::string object | 
|  | 101 | + */ | 
|  | 102 | +std::optional<std::string> get_env(const std::string& name) { | 
|  | 103 | +    const char* val = std::getenv(name.c_str()); | 
|  | 104 | +    if (!val) return std::nullopt; | 
|  | 105 | +    std::string res = std::string(val); | 
|  | 106 | +    std::transform(res.begin(), res.end(), res.begin(), ::tolower); | 
|  | 107 | +    return res; | 
|  | 108 | +} | 
|  | 109 | + | 
|  | 110 | +/** | 
|  | 111 | + * @brief Verify whether the environment variable is a valid value. | 
|  | 112 | + */ | 
|  | 113 | +bool parse_bool(const std::string& value) { | 
|  | 114 | +    std::unordered_set<std::string> valid_values = {"on", "1", "yes", "y", "enable", "true"}; | 
|  | 115 | +    return valid_values.find(value) != valid_values.end(); | 
|  | 116 | +} | 
|  | 117 | + | 
| 96 | 118 | /** | 
| 97 | 119 |  * @brief Initialize the CANN device information. | 
| 98 | 120 |  * | 
| @@ -214,7 +236,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool { | 
| 214 | 236 |      * @param device The device ID to associate with this buffer pool. | 
| 215 | 237 |      */ | 
| 216 | 238 |     explicit ggml_cann_pool_buf_prio(int device) : device(device) { | 
| 217 |  | -        disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr; | 
|  | 239 | +        disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); | 
| 218 | 240 |     } | 
| 219 | 241 | 
 | 
| 220 | 242 |     /** | 
| @@ -410,7 +432,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool { | 
| 410 | 432 |      * @param device The device ID to associate with this buffer pool. | 
| 411 | 433 |      */ | 
| 412 | 434 |     explicit ggml_cann_pool_buf(int device) : device(device) { | 
| 413 |  | -        disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr; | 
|  | 435 | +        disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); | 
| 414 | 436 |     } | 
| 415 | 437 | 
 | 
| 416 | 438 |     /** | 
| @@ -731,16 +753,18 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { | 
| 731 | 753 |  */ | 
| 732 | 754 | std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device( | 
| 733 | 755 |     int device) { | 
| 734 |  | -    bool disable_vmm = (getenv("GGML_CANN_DISABLE_VMM_POOL") != nullptr); | 
| 735 |  | -    if (!disable_vmm && ggml_cann_info().devices[device].vmm) { | 
| 736 |  | -        GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device); | 
| 737 |  | -        return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device)); | 
| 738 |  | -    } | 
| 739 |  | -    bool enable_buf_prio = (getenv("GGML_CANN_ENABLE_BUF_PRIO_POOL") != nullptr); | 
| 740 |  | -    if (enable_buf_prio) { | 
|  | 756 | +    std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or(""); | 
|  | 757 | + | 
|  | 758 | +    if (mem_pool_type == "prio") { | 
| 741 | 759 |         GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device); | 
| 742 | 760 |         return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf_prio(device)); | 
| 743 | 761 |     } | 
|  | 762 | + | 
|  | 763 | +    if (ggml_cann_info().devices[device].vmm && mem_pool_type != "leg") { | 
|  | 764 | +        GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device); | 
|  | 765 | +        return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device)); | 
|  | 766 | +    } | 
|  | 767 | + | 
| 744 | 768 |     GGML_LOG_INFO("%s: device %d use buffer pool\n", __func__, device); | 
| 745 | 769 |     return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf(device)); | 
| 746 | 770 | } | 
|  | 
0 commit comments