Skip to content

Commit 2628228

Browse files
bachelor-douggerganov
authored andcommitted
CANN: Simplify the environment variable setting(#13104)
* Simplify the environment variable setting to specify the memory pool type. * Adjust the GGML_CANN_ASYNC_MODE setting to accept yes, enable, 1, or on (case-insensitive) as valid options. * update * fix CI * update * delete whitespace * fix according to review * update CANN.md * update CANN.md
1 parent 4737a8c commit 2628228

File tree

2 files changed

+39
-10
lines changed

2 files changed

+39
-10
lines changed

ggml/src/ggml-cann/common.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include <thread>
3838
#include <unistd.h>
3939
#include <functional>
40+
#include <optional>
4041

4142
#include "../include/ggml-cann.h"
4243
#include "../include/ggml.h"
@@ -103,6 +104,9 @@ const ggml_cann_device_info& ggml_cann_info();
103104
void ggml_cann_set_device(int32_t device);
104105
int32_t ggml_cann_get_device();
105106

107+
std::optional<std::string> get_env(const std::string& name);
108+
bool parse_bool(const std::string& value);
109+
106110
/**
107111
* @brief Abstract base class for memory pools used by CANN.
108112
*/
@@ -354,7 +358,8 @@ struct ggml_backend_cann_context {
354358
: device(device), name("CANN" + std::to_string(device)), task_queue(1024, device) {
355359
ggml_cann_set_device(device);
356360
description = aclrtGetSocName();
357-
async_mode = (getenv("GGML_CANN_ASYNC_MODE") != nullptr);
361+
362+
bool async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
358363
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
359364
device, async_mode ? "ON" : "OFF");
360365
}

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
#include <mutex>
3232
#include <queue>
3333
#include <chrono>
34+
#include <unordered_set>
35+
#include <optional>
3436

3537
#include "ggml-impl.h"
3638
#include "ggml-backend-impl.h"
@@ -93,6 +95,26 @@ int32_t ggml_cann_get_device() {
9395
return id;
9496
}
9597

98+
/**
99+
* @brief Get the value of the specified environment variable (name).
100+
* if not empty, return a std::string object
101+
*/
102+
std::optional<std::string> get_env(const std::string& name) {
103+
const char* val = std::getenv(name.c_str());
104+
if (!val) return std::nullopt;
105+
std::string res = std::string(val);
106+
std::transform(res.begin(), res.end(), res.begin(), ::tolower);
107+
return res;
108+
}
109+
110+
/**
111+
* @brief Verify whether the environment variable is a valid value.
112+
*/
113+
bool parse_bool(const std::string& value) {
114+
std::unordered_set<std::string> valid_values = {"on", "1", "yes", "y", "enable", "true"};
115+
return valid_values.find(value) != valid_values.end();
116+
}
117+
96118
/**
97119
* @brief Initialize the CANN device information.
98120
*
@@ -214,7 +236,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
214236
* @param device The device ID to associate with this buffer pool.
215237
*/
216238
explicit ggml_cann_pool_buf_prio(int device) : device(device) {
217-
disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr;
239+
disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
218240
}
219241

220242
/**
@@ -410,7 +432,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
410432
* @param device The device ID to associate with this buffer pool.
411433
*/
412434
explicit ggml_cann_pool_buf(int device) : device(device) {
413-
disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr;
435+
disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
414436
}
415437

416438
/**
@@ -731,16 +753,18 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
731753
*/
732754
std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
733755
int device) {
734-
bool disable_vmm = (getenv("GGML_CANN_DISABLE_VMM_POOL") != nullptr);
735-
if (!disable_vmm && ggml_cann_info().devices[device].vmm) {
736-
GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device);
737-
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
738-
}
739-
bool enable_buf_prio = (getenv("GGML_CANN_ENABLE_BUF_PRIO_POOL") != nullptr);
740-
if (enable_buf_prio) {
756+
std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or("");
757+
758+
if (mem_pool_type == "prio") {
741759
GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device);
742760
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf_prio(device));
743761
}
762+
763+
if (ggml_cann_info().devices[device].vmm && mem_pool_type != "leg") {
764+
GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device);
765+
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
766+
}
767+
744768
GGML_LOG_INFO("%s: device %d use buffer pool\n", __func__, device);
745769
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf(device));
746770
}

0 commit comments

Comments
 (0)