Skip to content

Commit 79d3f3d

Browse files
committed
cann : refactor ACL graph cache
Move the graph property checking code into methods of LRU cache. Signed-off-by: Wang Weixuan <[email protected]>
1 parent 09d2e0d commit 79d3f3d

File tree

2 files changed

+174
-181
lines changed

2 files changed

+174
-181
lines changed

ggml/src/ggml-cann/common.h

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,60 @@ struct ggml_graph_node_properties {
354354
// op
355355
ggml_op node_op;
356356
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
357+
358+
/**
359+
* @brief Check if a ggml tensor node matches this property set.
360+
*
361+
* This function compares all relevant fields (address, op type, shape, source inputs, op params)
362+
* to determine whether the current node matches these previously recorded properties.
363+
*
364+
* @param node The current ggml tensor node.
365+
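* @return true if all compared fields match; false otherwise. Data-address checks are skipped for GGML_OP_VIEW nodes, and op params are compared only for GGML_OP_SCALE, GGML_OP_UNARY and GGML_OP_GLU.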
366+
*/
367+
bool has_matching_properties(ggml_tensor * node) {
    // VIEW nodes are exempt from every data-address comparison below.
    const bool is_view = (node->op == GGML_OP_VIEW);

    if (!is_view && node->data != this->node_address) {
        return false;
    }
    if (this->node_op != node->op) {
        return false;
    }

    // Shape and strides of the node itself.
    for (int dim = 0; dim < GGML_MAX_DIMS; ++dim) {
        const bool same_extent = (node->ne[dim] == this->ne[dim]);
        const bool same_stride = (node->nb[dim] == this->nb[dim]);
        if (!same_extent || !same_stride) {
            return false;
        }
    }

    // Every source slot: presence, address, shape and strides.
    for (int slot = 0; slot < GGML_MAX_SRC; ++slot) {
        const ggml_tensor * input = node->src[slot];

        if (input == nullptr) {
            // The slot is empty now, so it must have been empty when recorded.
            if (this->src_address[slot] != nullptr) {
                return false;
            }
            continue;
        }

        if (!is_view && input->data != this->src_address[slot]) {
            return false;
        }

        for (int dim = 0; dim < GGML_MAX_DIMS; ++dim) {
            if (input->ne[dim] != this->src_ne[slot][dim] || input->nb[dim] != this->src_nb[slot][dim]) {
                return false;
            }
        }
    }

    // Op params are only compared for these three op kinds.
    switch (node->op) {
        case GGML_OP_SCALE:
        case GGML_OP_UNARY:
        case GGML_OP_GLU:
            return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
        default:
            return true;
    }
}
357411
};
358412

359413
struct ggml_cann_graph {
@@ -366,6 +420,79 @@ struct ggml_cann_graph {
366420
aclmdlRI graph = nullptr;
367421

368422
std::vector<ggml_graph_node_properties> ggml_graph_properties;
423+
424+
/**
425+
* @brief Create a new CANN graph from a ggml computation graph.
426+
*
427+
* This function creates a new ggml_cann_graph object and fills its node properties
428+
* (operation type, dimensions, strides, input sources, and operation parameters)
429+
* based on the current ggml computation graph.
430+
*
431+
* Each node in the ggml graph is mapped to a property entry in the new CANN graph:
432+
* - node address
433+
* - operation type
434+
* - shape (ne) and strides (nb)
435+
* - source tensor addresses
436+
* - operation parameters
437+
*
438+
* @param cgraph The current ggml computation graph.
439+
* @return Pointer to the newly created ggml_cann_graph object.
440+
*/
441+
static ggml_cann_graph * create_from_cgraph(ggml_cgraph * cgraph) {
    ggml_cann_graph * result = new ggml_cann_graph();
    result->ggml_graph_properties.resize(cgraph->n_nodes);

    // Record one property entry per node, in graph order.
    for (int idx = 0; idx < cgraph->n_nodes; ++idx) {
        const ggml_tensor *          tensor = cgraph->nodes[idx];
        ggml_graph_node_properties & entry  = result->ggml_graph_properties[idx];

        // Identity of the node: output address, op kind and op parameters.
        entry.node_address = tensor->data;
        entry.node_op      = tensor->op;
        memcpy(entry.op_params, tensor->op_params, GGML_MAX_OP_PARAMS);

        // Shape and strides of the node output.
        std::copy_n(tensor->ne, GGML_MAX_DIMS, entry.ne);
        std::copy_n(tensor->nb, GGML_MAX_DIMS, entry.nb);

        // Source slots: copy what is present, zero out what is absent.
        for (int slot = 0; slot < GGML_MAX_SRC; ++slot) {
            const ggml_tensor * input = tensor->src[slot];

            if (input == nullptr) {
                entry.src_address[slot] = nullptr;
                std::fill_n(entry.src_ne[slot], GGML_MAX_DIMS, 0);
                std::fill_n(entry.src_nb[slot], GGML_MAX_DIMS, 0);
                continue;
            }

            entry.src_address[slot] = input->data;
            std::copy_n(input->ne, GGML_MAX_DIMS, entry.src_ne[slot]);
            std::copy_n(input->nb, GGML_MAX_DIMS, entry.src_nb[slot]);
        }
    }

    return result;
}
472+
473+
/**
474+
* @brief Check whether this CANN graph matches the given ggml computation graph.
475+
*
476+
* This function compares the number of nodes and each node's properties
477+
* (operation type, dimensions, strides, inputs, and operation parameters)
478+
* to determine whether this CANN graph matches the given ggml graph.
479+
*
480+
* @param cgraph The current ggml computation graph.
481+
* @return true if this CANN graph matches the ggml graph; false otherwise.
482+
*/
483+
bool matches_cgraph(ggml_cgraph * cgraph) {
    // A different node count can never match; this also keeps the indexing below in range.
    const size_t node_count = static_cast<size_t>(cgraph->n_nodes);
    if (node_count != this->ggml_graph_properties.size()) {
        return false;
    }

    // Compare node by node, stopping at the first mismatch.
    int idx = 0;
    while (idx < cgraph->n_nodes) {
        ggml_tensor * current = cgraph->nodes[idx];
        if (!this->ggml_graph_properties[idx].has_matching_properties(current)) {
            return false;
        }
        ++idx;
    }

    return true;
}
369496
};
370497

371498
/**
@@ -400,6 +527,7 @@ struct ggml_cann_graph_lru_cache {
400527
}
401528

402529
/**
* @brief Move an existing graph to the front of the cache.
* @param node Pointer to the ggml_cann_graph to move.
*/
@@ -409,6 +537,8 @@ struct ggml_cann_graph_lru_cache {
409537
}
410538

411539
/**
* @brief Clear all graphs from the cache (also frees memory).
*/
414544
void clear() {
@@ -421,8 +551,33 @@ struct ggml_cann_graph_lru_cache {
421551
/**
422552
* @brief Destructor that clears the cache and frees all cached graphs.
423553
*/
554+
<<<<<<< HEAD
424555
~ggml_cann_graph_lru_cache() {
425556
clear();
557+
=======
558+
~ggml_cann_graph_lru_cache() { clear(); }
559+
560+
/**
561+
* @brief Find a cached CANN graph that matches the given ggml graph and move it to front.
562+
*
563+
* This function iterates through the cached CANN graphs stored in the LRU cache and
564+
* compares them against the given ggml computation graph. If a matching graph is found,
565+
* it is promoted to the front of the LRU cache and returned. Otherwise, the function
566+
* returns nullptr.
567+
*
568+
* @param cgraph The current ggml computation graph.
569+
* @return true if found; false otherwise.
570+
*/
571+
bool find_and_move_to_front(ggml_cgraph * cgraph) {
572+
for (auto & graph_ptr : this->cache_list) {
573+
if (graph_ptr->matches_cgraph(cgraph)) {
574+
cache_list.remove(graph_ptr);
575+
cache_list.push_front(graph_ptr);
576+
return true;
577+
}
578+
}
579+
return false;
580+
>>>>>>> 0b15bb221 (cann : refactor ACL graph cache)
426581
}
427582
};
428583
#endif // USE_ACL_GRAPH

0 commit comments

Comments
 (0)