Skip to content

Commit 72b2e05

Browse files
noemotiovonwangweixuan
authored andcommitted
CANN: Improve ACL graph matching (ggml-org#16166)
* CANN: improve ACL graph matching Record `ne` and `nb` information for src tensors and include them in the graph matching check. This enhances the robustness of ACL graph matching by preventing incorrect matches when src tensors share the same data address but differ in shape or stride. * CANN: add op_params match
1 parent abb20d3 commit 72b2e05

File tree

2 files changed

+45
-12
lines changed

2 files changed

+45
-12
lines changed

ggml/src/ggml-cann/common.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,11 +341,18 @@ class cann_task_queue {
341341

342342
#ifdef USE_ACL_GRAPH
343343
struct ggml_graph_node_properties {
344+
// dst tensor
344345
void * node_address;
345-
ggml_op node_op;
346346
int64_t ne[GGML_MAX_DIMS];
347347
size_t nb[GGML_MAX_DIMS];
348+
349+
// src tensor
348350
void * src_address[GGML_MAX_SRC];
351+
int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS];
352+
size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS];
353+
354+
// op
355+
ggml_op node_op;
349356
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
350357
};
351358

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2133,7 +2133,15 @@ static void add_lru_matched_graph_node_properties(
21332133
std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
21342134

21352135
for (int src = 0; src < GGML_MAX_SRC; ++src) {
2136-
prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr;
2136+
if (node->src[src]) {
2137+
prop.src_address[src] = node->src[src]->data;
2138+
std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
2139+
std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
2140+
} else {
2141+
prop.src_address[src] = nullptr;
2142+
std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
2143+
std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
2144+
}
21372145
}
21382146

21392147
memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
@@ -2153,14 +2161,18 @@ static void add_lru_matched_graph_node_properties(
21532161
* @param graph_node_properties The stored properties of a CANN graph node.
21542162
* @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
21552163
*/
2156-
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
2164+
static bool ggml_graph_node_has_matching_properties(
2165+
ggml_tensor * node,
2166+
ggml_graph_node_properties * graph_node_properties) {
21572167
if (node->data != graph_node_properties->node_address &&
2158-
node->op != GGML_OP_VIEW) {
2168+
node->op != GGML_OP_VIEW) {
21592169
return false;
21602170
}
2171+
21612172
if (node->op != graph_node_properties->node_op) {
21622173
return false;
21632174
}
2175+
21642176
for (int i = 0; i < GGML_MAX_DIMS; i++) {
21652177
if (node->ne[i] != graph_node_properties->ne[i]) {
21662178
return false;
@@ -2169,17 +2181,31 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
21692181
return false;
21702182
}
21712183
}
2184+
21722185
for (int i = 0; i < GGML_MAX_SRC; i++) {
2173-
if (node->src[i] &&
2174-
node->src[i]->data != graph_node_properties->src_address[i] &&
2175-
node->op != GGML_OP_VIEW
2176-
) {
2177-
return false;
2186+
if (node->src[i]) {
2187+
if (node->src[i]->data != graph_node_properties->src_address[i] &&
2188+
node->op != GGML_OP_VIEW) {
2189+
return false;
2190+
}
2191+
2192+
for (int d = 0; d < GGML_MAX_DIMS; d++) {
2193+
if (node->src[i]->ne[d] != graph_node_properties->src_ne[i][d]) {
2194+
return false;
2195+
}
2196+
if (node->src[i]->nb[d] != graph_node_properties->src_nb[i][d]) {
2197+
return false;
2198+
}
2199+
}
2200+
} else {
2201+
if (graph_node_properties->src_address[i] != nullptr) {
2202+
return false;
2203+
}
21782204
}
21792205
}
2180-
if (node->op == GGML_OP_SCALE &&
2181-
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
2182-
return false;
2206+
2207+
if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
2208+
return memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
21832209
}
21842210
return true;
21852211
}

0 commit comments

Comments
 (0)