Skip to content

Commit 29d8430

Browse files
committed
update
1 parent 45c54b6 commit 29d8430

File tree

2 files changed

+28
-22
lines changed

2 files changed

+28
-22
lines changed

ggml/src/ggml-cann/common.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
#include <thread>
3838
#include <unistd.h>
3939
#include <functional>
40-
#include <set>
40+
#include <optional>
4141

4242
#include "../include/ggml-cann.h"
4343
#include "../include/ggml.h"
@@ -104,7 +104,8 @@ const ggml_cann_device_info& ggml_cann_info();
104104
void ggml_cann_set_device(int32_t device);
105105
int32_t ggml_cann_get_device();
106106

107-
static std::string to_lower_case(const char* env_var);
107+
std::optional<std::string> get_env(const std::string& name);
108+
bool parse_bool(const std::string& value);
108109

109110
/**
110111
* @brief Abstract base class for memory pools used by CANN.
@@ -358,11 +359,9 @@ struct ggml_backend_cann_context {
358359
ggml_cann_set_device(device);
359360
description = aclrtGetSocName();
360361

361-
std::string value = to_lower_case(getenv("GGML_CANN_ASYNC_MODE"));
362-
std::set<std::string> valid_values = {"on", "1", "yes", "y", "enable"};
363-
async_mode = valid_values.find(value) != valid_values.end();
362+
bool async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
364363
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
365-
device, async_mode ? "ON" : "OFF");
364+
device, async_mode ? "ON" : "OFF");
366365
}
367366

368367
/**

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
#include <mutex>
3232
#include <queue>
3333
#include <chrono>
34+
#include <set>
35+
#include <optional>
3436

3537
#include "ggml-impl.h"
3638
#include "ggml-backend-impl.h"
@@ -93,15 +95,23 @@ int32_t ggml_cann_get_device() {
9395
}
9496

9597
/**
96-
* @brief Convert the value obtained from getenv to a lowercase std::string.
97-
*
98-
* @param env_var C-style string(char*)
99-
* @return A string of type std::string.
98+
* @brief Get the value of the specified environment variable (name).
99+
* if not empty, return a std::string object
100100
*/
101-
static std::string to_lower_case(const char* env_var){
102-
std::string mem_pool_type(env_var ? env_var : "");
103-
std::transform(mem_pool_type.begin(), mem_pool_type.end(), mem_pool_type.begin(), ::tolower);
104-
return mem_pool_type;
101+
std::optional<std::string> get_env(const std::string& name) {
102+
const char* val = std::getenv(name.c_str());
103+
if (!val) return std::nullopt;
104+
return std::string(val);
105+
}
106+
107+
/**
108+
* @brief Verify whether the environment variable is a valid value.
109+
*/
110+
bool parse_bool(const std::string& value) {
111+
std::string res = value;
112+
std::transform(res.begin(), res.end(), res.begin(), ::tolower);
113+
std::set<std::string> valid_values = {"on", "1", "yes", "y", "enable", "true"};
114+
return valid_values.find(res) != valid_values.end();
105115
}
106116

107117
/**
@@ -225,7 +235,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
225235
* @param device The device ID to associate with this buffer pool.
226236
*/
227237
explicit ggml_cann_pool_buf_prio(int device) : device(device) {
228-
disable_clean = getenv("GGML_CANN_POOL_DISABLE_CLEAN") != nullptr;
238+
disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
229239
}
230240

231241
/**
@@ -421,7 +431,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
421431
* @param device The device ID to associate with this buffer pool.
422432
*/
423433
explicit ggml_cann_pool_buf(int device) : device(device) {
424-
disable_clean = getenv("GGML_CANN_POOL_DISABLE_CLEAN") != nullptr;
434+
disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
425435
}
426436

427437
/**
@@ -742,20 +752,17 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
742752
*/
743753
std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
744754
int device) {
745-
const char* env_var = getenv("GGML_CANN_MEM_POOL");
746-
std::string mem_pool_type = to_lower_case(env_var);
755+
std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or("");
756+
std::transform(mem_pool_type.begin(), mem_pool_type.end(), mem_pool_type.begin(), ::tolower);
747757

748758
if (mem_pool_type == "prio") {
749759
GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device);
750760
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf_prio(device));
751761
}
752762

753-
if (mem_pool_type.empty() && ggml_cann_info().devices[device].vmm) {
763+
if (ggml_cann_info().devices[device].vmm && mem_pool_type != "leg") {
754764
GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device);
755765
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
756-
}else{
757-
GGML_LOG_INFO("%s: device %d use buffer pool\n", __func__, device);
758-
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf(device));
759766
}
760767

761768
GGML_LOG_INFO("%s: device %d use buffer pool\n", __func__, device);

0 commit comments

Comments
 (0)