Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/backend/CANN.md
Original file line number Diff line number Diff line change
Expand Up @@ -314,3 +314,7 @@ Converting the matmul weight format from ND to NZ to improve performance. Enable
### GGML_CANN_ACL_GRAPH

Operators are executed using ACL graph execution, rather than in op-by-op (eager) mode. Enabled by default.

### GGML_CANN_GRAPH_CACHE_CAPACITY

Maximum number of compiled CANN graphs kept in the LRU cache, default is 12. When the number of cached graphs exceeds this capacity, the least recently used graph will be evicted.
63 changes: 62 additions & 1 deletion ggml/src/ggml-cann/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include <unistd.h>
#include <functional>
#include <optional>
#include <list>

#include "../include/ggml-cann.h"
#include "../include/ggml.h"
Expand Down Expand Up @@ -358,6 +359,66 @@ struct ggml_cann_graph {

std::vector<ggml_graph_node_properties> ggml_graph_properties;
};

/**
* @brief LRU cache for managing ggml_cann_graph objects.
*
* This class maintains a list of shared_ptr to ggml_cann_graph objects
* and enforces a maximum capacity. It provides methods to push new graphs,
* move existing graphs to the front (most recently used), and clear the cache.
*/
struct ggml_cann_graph_lru_cache {
size_t capacity; /**< Maximum number of graphs in the cache. */

std::list<std::shared_ptr<ggml_cann_graph>> cache_list; /**< List storing cached graphs. */

std::shared_ptr<ggml_cann_graph> matched_graph = nullptr; /**< Pointer to a recently matched graph. */

ggml_cann_graph_lru_cache() {
std::string env_val = get_env("GGML_CANN_GRAPH_CACHE_CAPACITY").value_or("12");
try {
capacity = std::stoul(env_val);
} catch (...) {
capacity = 12; // fallback to default if invalid
}
}

/**
* @brief Push a new graph to the front of the cache.
* If the cache exceeds capacity, the least recently used graph is removed.
* @param new_node Shared pointer to the new ggml_cann_graph to cache.
*/
void push(std::shared_ptr<ggml_cann_graph> new_node) {
if (cache_list.size() >= capacity) {
cache_list.pop_back();
}

cache_list.push_front(new_node);
}

/**
* @brief Move an existing graph to the front of the cache.
* @param node Shared pointer to the ggml_cann_graph to move.
*/
void move_to_front(std::shared_ptr<ggml_cann_graph> node) {
cache_list.remove(node);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Delete a list in array will go through all elements in array. It's better to use priority queue

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current implementation has a time complexity of O(n), but even if I switch to a priority queue, it would still require a full traversal. I plan to add a map member variable to reduce the time complexity to O(1).

cache_list.push_front(node);
}

/**
* @brief Clear all graphs from the cache.
*/
void clear() {
cache_list.clear();
}

/**
* @brief Destructor that clears the cache upon object destruction.
*/
~ggml_cann_graph_lru_cache() {
clear();
}
};
#endif // USE_ACL_GRAPH

struct ggml_cann_rope_cache {
Expand Down Expand Up @@ -394,7 +455,7 @@ struct ggml_backend_cann_context {
aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
#ifdef USE_ACL_GRAPH
/// Cached CANN ACL graph used for executing the current ggml computation graph.
std::unique_ptr<ggml_cann_graph> cann_graph;
ggml_cann_graph_lru_cache graph_lru_cache;
bool acl_graph_mode = true;
#endif
cann_task_queue task_queue;
Expand Down
88 changes: 50 additions & 38 deletions ggml/src/ggml-cann/ggml-cann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2140,21 +2140,31 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
* @param cgraph The ggml computational graph.
*/
static void set_ggml_graph_node_properties(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {
ggml_tensor * node = cgraph->nodes[node_idx];
cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_address = node->data;
cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_op = node->op;

for (int dim = 0; dim < GGML_MAX_DIMS; dim++) {
cann_ctx->cann_graph->ggml_graph_properties[node_idx].ne[dim] = node->ne[dim];
cann_ctx->cann_graph->ggml_graph_properties[node_idx].nb[dim] = node->nb[dim];
}
for (int src = 0; src < GGML_MAX_SRC; src++) {
cann_ctx->cann_graph->ggml_graph_properties[node_idx].src_address[src] =
node->src[src] ? node->src[src]->data : nullptr;
std::shared_ptr<ggml_cann_graph> &matched_graph = cann_ctx->graph_lru_cache.matched_graph;
if (!matched_graph) {
matched_graph.reset(new ggml_cann_graph());
matched_graph->ggml_graph_properties.resize(cgraph->n_nodes);
for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {
ggml_tensor * node = cgraph->nodes[node_idx];
matched_graph->ggml_graph_properties[node_idx].node_address = node->data;
matched_graph->ggml_graph_properties[node_idx].node_op = node->op;

for (int dim = 0; dim < GGML_MAX_DIMS; dim++) {
matched_graph->ggml_graph_properties[node_idx].ne[dim] = node->ne[dim];
matched_graph->ggml_graph_properties[node_idx].nb[dim] = node->nb[dim];
}
for (int src = 0; src < GGML_MAX_SRC; src++) {
matched_graph->ggml_graph_properties[node_idx].src_address[src] =
node->src[src] ? node->src[src]->data : nullptr;
}
memcpy(matched_graph->ggml_graph_properties[node_idx].op_params, node->op_params, GGML_MAX_OP_PARAMS);
}
memcpy(cann_ctx->cann_graph->ggml_graph_properties[node_idx].op_params, node->op_params, GGML_MAX_OP_PARAMS);

cann_ctx->graph_lru_cache.push(matched_graph);
} else {
cann_ctx->graph_lru_cache.move_to_front(matched_graph);
}
return;
}

/**
Expand Down Expand Up @@ -2209,21 +2219,29 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
* @return true if an update is required; false otherwise.
*/
static bool is_cann_graph_update_required(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
// The number of nodes is different, so the graph needs to be reconstructed.
if (cann_ctx->cann_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
cann_ctx->cann_graph->ggml_graph_properties.resize(cgraph->n_nodes);
return true;
}

// The number of nodes is the same; iterate over each node to check whether they match.
for (int i = 0; i < cgraph->n_nodes; i++) {
bool has_matching_properties = ggml_graph_node_has_matching_properties(
cgraph->nodes[i], &cann_ctx->cann_graph->ggml_graph_properties[i]);
if(!has_matching_properties) {
return true;
ggml_cann_graph_lru_cache &lru_cache = cann_ctx->graph_lru_cache;
for (auto &graph_ptr : lru_cache.cache_list) {
// The number of nodes is different, so the graph needs to be reconstructed.
if (graph_ptr->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
continue;
}
// The number of nodes is the same; iterate over each node to check whether they match.
bool all_match = true;
for (int i = 0; i < cgraph->n_nodes; i++) {
bool has_matching_properties = ggml_graph_node_has_matching_properties(
cgraph->nodes[i], &graph_ptr->ggml_graph_properties[i]);
if(!has_matching_properties) {
all_match = false;
break;
}
}
if (all_match) {
lru_cache.matched_graph = graph_ptr;
return false;
}
}
return false;
lru_cache.matched_graph = nullptr;
return true;
}
#endif // USE_ACL_GRAPH

Expand All @@ -2244,14 +2262,13 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
bool & use_cann_graph, bool & cann_graph_update_required) {
#ifdef USE_ACL_GRAPH
if (use_cann_graph && cann_graph_update_required) {
if (cann_ctx->cann_graph->graph != nullptr) {
ACL_CHECK(aclmdlRIDestroy(cann_ctx->cann_graph->graph));
cann_ctx->cann_graph->graph = nullptr;
if (cann_ctx->graph_lru_cache.matched_graph->graph != nullptr) {
ACL_CHECK(aclmdlRIDestroy(cann_ctx->graph_lru_cache.matched_graph->graph));
cann_ctx->graph_lru_cache.matched_graph->graph = nullptr;
}
ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
}
#endif // USE_ACL_GRAPH

// Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
// With the use of CANN graphs, the execution will be performed by the graph launch.
if (!use_cann_graph || cann_graph_update_required) {
Expand All @@ -2272,12 +2289,12 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx

#ifdef USE_ACL_GRAPH
if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture
ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &cann_ctx->cann_graph->graph));
ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &cann_ctx->graph_lru_cache.matched_graph->graph));
}

if (use_cann_graph) {
// Execute graph
ACL_CHECK(aclmdlRIExecuteAsync(cann_ctx->cann_graph->graph, cann_ctx->stream()));
ACL_CHECK(aclmdlRIExecuteAsync(cann_ctx->graph_lru_cache.matched_graph->graph, cann_ctx->stream()));
}
#endif // USE_ACL_GRAPH
}
Expand Down Expand Up @@ -2311,19 +2328,14 @@ static enum ggml_status ggml_backend_cann_graph_compute(
}

if (use_cann_graph) {
if (cann_ctx->cann_graph == nullptr) {
cann_ctx->cann_graph.reset(new ggml_cann_graph());
cann_graph_update_required = true;
}

// TODO: refactor to lru_cache
cann_graph_update_required = is_cann_graph_update_required(cann_ctx, cgraph);
set_ggml_graph_node_properties(cann_ctx, cgraph);
}
#else
bool use_cann_graph = false;
bool cann_graph_update_required = false;
#endif // USE_ACL_GRAPH

evaluate_and_capture_cann_graph(
cann_ctx,
cgraph,
Expand Down
Loading