Skip to content

Commit c406a52

Browse files
committed
cann : refactor ACL graph cache
Move the graph property checking code into methods of LRU cache. Signed-off-by: Wang Weixuan <[email protected]>
1 parent 09d2e0d commit c406a52

File tree

2 files changed

+168
-192
lines changed

2 files changed

+168
-192
lines changed

ggml/src/ggml-cann/common.h

Lines changed: 149 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,60 @@ struct ggml_graph_node_properties {
354354
// op
355355
ggml_op node_op;
356356
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
357+
358+
/**
359+
* @brief Check if a ggml tensor node matches this property set.
360+
*
361+
* This function compares all relevant fields (address, op type, shape, source inputs, op params)
362+
* to determine whether the current node matches these previously recorded properties.
363+
*
364+
* @param node The current ggml tensor node.
365+
* @return true if all fields match (the data-address comparisons are skipped for GGML_OP_VIEW nodes); false otherwise.
366+
*/
367+
bool has_matching_properties(ggml_tensor * node) {
    // View nodes are exempt from the data-address comparisons (node and sources);
    // every other check below still applies to them.
    const bool is_view = (node->op == GGML_OP_VIEW);

    if (!is_view && node->data != this->node_address) {
        return false;
    }
    if (this->node_op != node->op) {
        return false;
    }

    // Shape (ne) and strides (nb) of the node itself.
    for (int dim = 0; dim < GGML_MAX_DIMS; ++dim) {
        if (node->ne[dim] != this->ne[dim] || node->nb[dim] != this->nb[dim]) {
            return false;
        }
    }

    // Each source slot: address (unless this is a view), shape and strides.
    for (int slot = 0; slot < GGML_MAX_SRC; ++slot) {
        ggml_tensor * input = node->src[slot];

        if (input == nullptr) {
            // A missing source only matches a slot whose recorded address is null.
            if (this->src_address[slot] != nullptr) {
                return false;
            }
            continue;
        }

        if (!is_view && input->data != this->src_address[slot]) {
            return false;
        }

        for (int dim = 0; dim < GGML_MAX_DIMS; ++dim) {
            if (input->ne[dim] != this->src_ne[slot][dim] || input->nb[dim] != this->src_nb[slot][dim]) {
                return false;
            }
        }
    }

    // Operation parameters are only compared for these three op kinds.
    const bool compares_op_params =
        node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU;
    if (!compares_op_params) {
        return true;
    }
    return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
}
357411
};
358412

359413
struct ggml_cann_graph {
@@ -366,6 +420,79 @@ struct ggml_cann_graph {
366420
aclmdlRI graph = nullptr;
367421

368422
std::vector<ggml_graph_node_properties> ggml_graph_properties;
423+
424+
/**
425+
* @brief Create a new CANN graph from a ggml computation graph.
426+
*
427+
* This function creates a new ggml_cann_graph object and fills its node properties
428+
* (operation type, dimensions, strides, input sources, and operation parameters)
429+
* based on the current ggml computation graph.
430+
*
431+
* Each node in the ggml graph is mapped to a property entry in the new CANN graph:
432+
* - node address
433+
* - operation type
434+
* - shape (ne) and strides (nb)
435+
* - source tensor addresses
436+
* - operation parameters
437+
*
438+
* @param cgraph The current ggml computation graph.
439+
* @return Pointer to the newly created ggml_cann_graph object.
440+
*/
441+
static ggml_cann_graph * create_from_cgraph(ggml_cgraph * cgraph) {
    // The caller receives ownership of the heap-allocated graph.
    ggml_cann_graph * result = new ggml_cann_graph();
    result->ggml_graph_properties.resize(cgraph->n_nodes);

    for (int n = 0; n < cgraph->n_nodes; ++n) {
        ggml_tensor *                tensor = cgraph->nodes[n];
        ggml_graph_node_properties & entry  = result->ggml_graph_properties[n];

        // Identity of the node: output buffer address and operation.
        entry.node_address = tensor->data;
        entry.node_op      = tensor->op;

        // Shape and strides of the node.
        for (int d = 0; d < GGML_MAX_DIMS; ++d) {
            entry.ne[d] = tensor->ne[d];
            entry.nb[d] = tensor->nb[d];
        }

        // Snapshot every source slot; absent sources are recorded as null address
        // with zeroed shape/strides.
        for (int slot = 0; slot < GGML_MAX_SRC; ++slot) {
            ggml_tensor * input = tensor->src[slot];

            if (input == nullptr) {
                entry.src_address[slot] = nullptr;
                for (int d = 0; d < GGML_MAX_DIMS; ++d) {
                    entry.src_ne[slot][d] = 0;
                    entry.src_nb[slot][d] = 0;
                }
                continue;
            }

            entry.src_address[slot] = input->data;
            for (int d = 0; d < GGML_MAX_DIMS; ++d) {
                entry.src_ne[slot][d] = input->ne[d];
                entry.src_nb[slot][d] = input->nb[d];
            }
        }

        // Raw byte copy of the full operation-parameter block.
        memcpy(entry.op_params, tensor->op_params, GGML_MAX_OP_PARAMS);
    }

    return result;
}
472+
473+
/**
474+
* @brief Check whether this CANN graph matches the given ggml computation graph.
475+
*
476+
* This function compares the number of nodes and each node's properties
477+
* (operation type, dimensions, strides, inputs, and operation parameters)
478+
* to determine whether this CANN graph matches the given ggml graph.
479+
*
480+
* @param cgraph The current ggml computation graph.
481+
* @return true if this CANN graph matches the ggml graph; false otherwise.
482+
*/
483+
bool matches_cgraph(ggml_cgraph * cgraph) {
    // A different node count can never match; this also guards the indexing below.
    const size_t node_count = static_cast<size_t>(cgraph->n_nodes);
    if (node_count != this->ggml_graph_properties.size()) {
        return false;
    }

    // Compare node by node, stopping at the first mismatch.
    for (size_t idx = 0; idx < node_count; ++idx) {
        ggml_graph_node_properties & recorded = this->ggml_graph_properties[idx];
        if (!recorded.has_matching_properties(cgraph->nodes[idx])) {
            return false;
        }
    }

    return true;
}
369496
};
370497

371498
/**
@@ -399,15 +526,6 @@ struct ggml_cann_graph_lru_cache {
399526
cache_list.push_front(new_node);
400527
}
401528

402-
/**
403-
* @brief Move an existing graph to the front of the cache.
404-
* @param node Pointer to the ggml_cann_graph to move.
405-
*/
406-
void move_to_front(ggml_cann_graph* node) {
407-
cache_list.remove(node);
408-
cache_list.push_front(node);
409-
}
410-
411529
/**
412530
* @brief Clear all graphs from the cache (also frees memory).
413531
*/
@@ -421,8 +539,28 @@ struct ggml_cann_graph_lru_cache {
421539
/**
422540
* @brief Destructor that clears the cache and frees all cached graphs.
423541
*/
424-
~ggml_cann_graph_lru_cache() {
425-
clear();
542+
~ggml_cann_graph_lru_cache() {
    // All cleanup is delegated to clear().
    clear();
}
543+
544+
/**
545+
* @brief Find a cached CANN graph that matches the given ggml graph and move it to front.
546+
*
547+
* This function iterates through the cached CANN graphs stored in the LRU cache and
548+
* compares them against the given ggml computation graph. If a matching graph is found,
549+
* it is promoted to the front of the LRU cache and the function returns true. Otherwise,
549+
* the cache is left unchanged and the function returns false.
551+
*
552+
* @param cgraph The current ggml computation graph.
553+
* @return true if found; false otherwise.
554+
*/
555+
bool find_and_move_to_front(ggml_cgraph * cgraph) {
    // Scan the cached graphs in list order; the first match wins.
    for (auto & graph_ptr : this->cache_list) {
        if (!graph_ptr->matches_cgraph(cgraph)) {
            continue;
        }

        // Copy the pointer out of the list element BEFORE mutating the list.
        // `graph_ptr` is a reference into the list node: remove() destroys that node,
        // so passing the reference on to push_front() afterwards would read a
        // dangling reference (use-after-free).
        ggml_cann_graph * matched = graph_ptr;

        cache_list.remove(matched);
        cache_list.push_front(matched);

        // Return immediately: the range-for iterator was invalidated by remove().
        return true;
    }
    return false;
}
427565
};
428566
#endif // USE_ACL_GRAPH

0 commit comments

Comments
 (0)