Skip to content

Commit 0e73651

Browse files
committed
[CANN]Support ACL GRAPH2
Signed-off-by: noemotiovon <[email protected]>
1 parent 001c1d4 commit 0e73651

File tree

3 files changed

+107
-170
lines changed

3 files changed

+107
-170
lines changed

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,16 @@ static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src,
714714
void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
715715
ggml_tensor* src0 = dst->src[0];
716716

717+
// ---- TODO: ----
718+
// if (ctx.cann_graph && ctx.cann_graph->use_cpy_indirection) {
719+
// int &idx = ctx.cann_graph->graph_cpynode_index; // reference, so it can be incremented in place
720+
// if (idx < ctx.cann_graph->dest_ptrs_size && ctx.cann_graph->dest_ptrs_d != nullptr) {
721+
// dst->data = ggml_cann_get_dynamic_dst_ptr(ctx, idx);
722+
// idx++;
723+
// } else {
724+
// std::cerr << "Warning: graph_cpynode_index out of range or dest_ptrs_d null!" << std::endl;
725+
// }
726+
// }
717727
aclTensor* acl_src = ggml_cann_create_tensor(src0);
718728
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
719729
if (ggml_are_same_shape(src0, dst)) {

ggml/src/ggml-cann/common.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -376,10 +376,10 @@ struct ggml_cann_graph {
376376
std::vector<ggml_graph_node_properties> ggml_graph_properties;
377377

378378
// TODO: use cpy indirection
379-
// bool use_cpy_indirection = false;
380-
// std::vector<char *> cpy_dest_ptrs;
381-
// char ** dest_ptrs_d = nullptr;
382-
// int dest_ptrs_size = 0;
379+
bool use_cpy_indirection = false;
380+
std::vector<char *> cpy_dest_ptrs;
381+
char ** dest_ptrs_d = nullptr;
382+
int dest_ptrs_size = 0;
383383

384384
int graph_cpynode_index = -1;
385385
#endif // USE_CANN_GRAPH

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 93 additions & 166 deletions
Original file line numberDiff line numberDiff line change
@@ -1945,10 +1945,51 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
19451945

19461946
#ifdef USE_CANN_GRAPH
19471947

1948+
/**
 * @brief Uploads the host-side table of copy-destination pointers to the device.
 *
 * Grows the device buffer `cann_graph->dest_ptrs_d` when it cannot hold
 * `host_dest_ptrs_size` entries, copies the host table into it, and resets
 * `cann_graph->graph_cpynode_index` to 0.
 *
 * This function is compiled inside the surrounding `#ifdef USE_CANN_GRAPH`
 * region, so no additional guard is needed in the body.
 *
 * @param cann_graph          Graph state that holds the device pointer table.
 * @param host_dest_ptrs      Host array of destination pointers to upload.
 * @param host_dest_ptrs_size Number of entries in @p host_dest_ptrs.
 * @param stream              Stream that is synchronized around the update.
 */
void ggml_cann_cpy_dest_ptrs_copy(ggml_cann_graph * cann_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, aclrtStream stream) {
    // Nothing to upload. A non-positive count is rejected too, so the byte
    // size computed below can never come from a negative value.
    if (host_dest_ptrs_size <= 0 || host_dest_ptrs == nullptr) {
        return;
    }

    const auto table_bytes = host_dest_ptrs_size * sizeof(char *);

    if (cann_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate NPU memory for destination pointers
        // Wait for work queued on the stream before releasing the old table.
        ACL_CHECK(aclrtSynchronizeStream(stream));
        if (cann_graph->dest_ptrs_d != nullptr) {
            ACL_CHECK(aclrtFree(cann_graph->dest_ptrs_d));
            // Do not leave a dangling pointer / stale size behind while the
            // replacement buffer is being allocated.
            cann_graph->dest_ptrs_d    = nullptr;
            cann_graph->dest_ptrs_size = 0;
        }
        // TODO: check ACL_MEM_MALLOC_NORMAL_ONLY
        void * device_ptr = nullptr;
        ACL_CHECK(aclrtMalloc(&device_ptr, table_bytes, ACL_MEM_MALLOC_NORMAL_ONLY));
        cann_graph->dest_ptrs_d    = static_cast<char **>(device_ptr);
        cann_graph->dest_ptrs_size = host_dest_ptrs_size;
    }

    // Copy the destination pointers to the NPU.
    //
    // cann_graph->dest_ptrs_d (device pointer, stored on the device)
    //
    //   +---------------------+
    //   | 0x600000000         | <-- device copy address 1
    //   +---------------------+
    //   | 0x600010000         | <-- device copy address 2
    //   +---------------------+
    //   | 0x600020000         | <-- device copy address 3
    //   +---------------------+
    ACL_CHECK(aclrtMemcpy(cann_graph->dest_ptrs_d, table_bytes, host_dest_ptrs, table_bytes,
                          ACL_MEMCPY_HOST_TO_DEVICE));

    ACL_CHECK(aclrtSynchronizeStream(stream));
    cann_graph->graph_cpynode_index = 0; // reset index
}
1988+
19481989
static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph,
19491990
bool use_cann_graph) {
19501991
// Loop over nodes in GGML graph to obtain info needed for CANN graph
1951-
// cann_ctx->cann_graph->cpy_dest_ptrs.clear();
1992+
cann_ctx->cann_graph->cpy_dest_ptrs.clear();
19521993

19531994
for (int i = 0; i < cgraph->n_nodes; i++) {
19541995
ggml_tensor * node = cgraph->nodes[i];
@@ -1958,26 +1999,38 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_can
19581999
}
19592000

19602001
if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
2002+
// std::cout << "lcg check_node_graph_compatibility_and_refresh_copy_ops start1:" << std::endl;
19612003
use_cann_graph = false; // This node type is not supported by CANN graph capture
19622004
#ifndef NDEBUG
19632005
GGML_LOG_DEBUG("%s: disabling CANN graphs due to unsupported node type\n", __func__);
19642006
#endif
19652007
}
19662008

1967-
if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
1968-
// disable CANN graphs for batch size > 1 for now.
1969-
// Changes in batch size or context size can cause changes to the grid size of some kernels.
1970-
use_cann_graph = false;
1971-
#ifndef NDEBUG
1972-
GGML_LOG_DEBUG("%s: disabling CANN graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
1973-
#endif
2009+
// if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
2010+
// // disable CANN graphs for batch size > 1 for now.
2011+
// // Changes in batch size or context size can cause changes to the grid size of some kernels.
2012+
// use_cann_graph = false;
2013+
// #ifndef NDEBUG
2014+
// GGML_LOG_DEBUG("%s: disabling CANN graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
2015+
// #endif
2016+
// }
2017+
if (node->op == GGML_OP_CPY) {
2018+
// Store the pointers which are updated for each token, such that these can be sent
2019+
// to the device and accessed using indirection from CANN graph
2020+
cann_ctx->cann_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
19742021
}
19752022

19762023
if (!use_cann_graph) {
19772024
break;
19782025
}
19792026
}
19802027

2028+
if (use_cann_graph) {
2029+
cann_ctx->cann_graph->use_cpy_indirection = true;
2030+
// copy pointers to NPU so they can be accessed via indirection within CANN graph
2031+
ggml_cann_cpy_dest_ptrs_copy(cann_ctx->cann_graph.get(), cann_ctx->cann_graph->cpy_dest_ptrs.data(), cann_ctx->cann_graph->cpy_dest_ptrs.size(), cann_ctx->stream());
2032+
}
2033+
19812034
return use_cann_graph;
19822035
}
19832036

@@ -1994,53 +2047,17 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
19942047
memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
19952048
}
19962049

1997-
// static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
1998-
// if (node->data != graph_node_properties->node_address &&
1999-
// // TODO: node->op != GGML_OP_CPY &&
2000-
// node->op != GGML_OP_VIEW) {
2001-
// bool stdout = node->data != graph_node_properties->node_address;
2002-
// std::cout << "lcg is_cann_graph_update_required start3:" << stdout <<std::endl;
2003-
// return false;
2004-
// }
2005-
2006-
// if (node->op != graph_node_properties->node_op) {
2007-
// return false;
2008-
// }
2009-
2010-
// for (int i = 0; i < GGML_MAX_DIMS; i++) {
2011-
// if (node->ne[i] != graph_node_properties->ne[i]) {
2012-
// return false;
2013-
// }
2014-
// if (node->nb[i] != graph_node_properties->nb[i]) {
2015-
// return false;
2016-
// }
2017-
// }
2018-
2019-
// for (int i = 0; i < GGML_MAX_SRC; i++) {
2020-
// if (node->src[i] &&
2021-
// node->src[i]->data != graph_node_properties->src_address[i] &&
2022-
// // TODO: node->op != GGML_OP_CPY &&
2023-
// node->op != GGML_OP_VIEW
2024-
// ) {
2025-
// return false;
2026-
// }
2027-
// }
2028-
2029-
// if (node->op == GGML_OP_SCALE &&
2030-
// memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
2031-
// return false;
2032-
// }
2033-
2034-
// return true;
2035-
// }
2036-
20372050
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
2038-
// 1. 检查操作类型
2051+
if (node->data != graph_node_properties->node_address &&
2052+
node->op != GGML_OP_CPY &&
2053+
node->op != GGML_OP_VIEW) {
2054+
return false;
2055+
}
2056+
20392057
if (node->op != graph_node_properties->node_op) {
20402058
return false;
20412059
}
20422060

2043-
// 2. 检查 shape 和 stride
20442061
for (int i = 0; i < GGML_MAX_DIMS; i++) {
20452062
if (node->ne[i] != graph_node_properties->ne[i]) {
20462063
return false;
@@ -2050,18 +2067,21 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
20502067
}
20512068
}
20522069

2053-
// 3. 可选:检查输入数量是否相同
2054-
// (如果你想进一步严格匹配,可以加入输入数量和逻辑检查)
2070+
for (int i = 0; i < GGML_MAX_SRC; i++) {
2071+
if (node->src[i] &&
2072+
node->src[i]->data != graph_node_properties->src_address[i] &&
2073+
node->op != GGML_OP_CPY &&
2074+
node->op != GGML_OP_VIEW
2075+
) {
2076+
return false;
2077+
}
2078+
}
20552079

2056-
// 4. 检查 op 参数(针对特定 op)
20572080
if (node->op == GGML_OP_SCALE &&
20582081
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
20592082
return false;
20602083
}
20612084

2062-
// 5. 不再检查 data 地址
2063-
// (为了避免因为 allocator/memory pool 导致的地址变化)
2064-
20652085
return true;
20662086
}
20672087

@@ -2134,7 +2154,7 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
21342154
}
21352155
ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
21362156
}
2137-
#endif // USE_CUDA_GRAPH
2157+
#endif // USE_CANN_GRAPH
21382158

21392159
// Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
21402160
// With the use of CANN graphs, the execution will be performed by the graph launch.
@@ -2168,94 +2188,9 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
21682188
ACL_CHECK(aclmdlRIExecuteAsync(cann_ctx->cann_graph->graph, cann_ctx->stream()));
21692189
#else
21702190
graph_evaluated_or_captured = true;
2171-
#endif // USE_CUDA_GRAPH
2172-
}
2173-
}
2174-
2175-
// /**
2176-
// * @brief Computes a computational graph using a CANN backend.
2177-
// *
2178-
// * This function computes the operations defined in the computational graph
2179-
// * using the specified CANN backend.
2180-
// *
2181-
// * @param backend Pointer to the CANN backend structure to use for computation.
2182-
// * @param cgraph Pointer to the computational graph structure containing nodes
2183-
// * representing operations to be computed.
2184-
// * @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation
2185-
// * completes successfully, otherwise an appropriate error status.
2186-
// */
2187-
// static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
2188-
// ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *)backend->context;
2189-
2190-
// ggml_cann_set_device(cann_ctx->device);
2191-
2192-
// #ifdef USE_CANN_GRAPH
2193-
// static const bool disable_cann_graphs_due_to_env = (getenv("GGML_CANN_DISABLE_GRAPHS") != nullptr);
2194-
2195-
// // Objects required for CANN Graph
2196-
// if (cann_ctx->cann_graph == nullptr) {
2197-
// cann_ctx->cann_graph.reset(new ggml_cann_graph());
2198-
// }
2199-
2200-
// bool use_cann_graph = true;
2201-
// bool cann_graph_update_required = false;
2202-
2203-
// if (cann_ctx->cann_graph->graph == nullptr) {
2204-
// // TODO: Add not support device deal
2205-
// if (ggml_cann_info().devices[cann_ctx->device].cc < 0) {
2206-
// cann_ctx->cann_graph->disable_due_to_npu_arch = true;
2207-
// #ifndef NDEBUG
2208-
// GGML_LOG_DEBUG("%s: disabling CANN graphs due to CANN architecture\n", __func__);
2209-
// #endif
2210-
// }
2211-
// }
2212-
2213-
// // Disable CANN graphs in presence of env var, old NPU, use-case which is changing too rapidly,
2214-
// // or previous graph capture failure.
2215-
// // Also disable for multi-npu for now. TO DO investigate
2216-
// if (disable_cann_graphs_due_to_env
2217-
// || cann_ctx->cann_graph->disable_due_to_npu_arch
2218-
// || cann_ctx->cann_graph->disable_due_to_too_many_updates
2219-
// || cann_ctx->cann_graph->disable_due_to_failed_graph_capture) {
2220-
// use_cann_graph = false;
2221-
// }
2222-
2223-
// if (use_cann_graph) {
2224-
// cann_graph_update_required = is_cann_graph_update_required(cann_ctx, cgraph);
2225-
2226-
// use_cann_graph = check_node_graph_compatibility_and_refresh_copy_ops(cann_ctx, cgraph, cann_graph_update_required);
2227-
2228-
// // Disable CANN graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
2229-
// if (use_cann_graph && cann_graph_update_required) {
2230-
// cann_ctx->cann_graph->number_consecutive_updates++;
2231-
// } else {
2232-
// cann_ctx->cann_graph->number_consecutive_updates = 0;
2233-
// }
2234-
2235-
// if (cann_ctx->cann_graph->number_consecutive_updates >= 4) {
2236-
// cann_ctx->cann_graph->disable_due_to_too_many_updates = true;
2237-
// #ifndef NDEBUG
2238-
// GGML_LOG_DEBUG("%s: disabling CANN graphs due to too many consecutive updates\n", __func__);
2239-
// #endif
2240-
// }
2241-
// }
2242-
2243-
// if (use_cann_graph && cann_graph_update_required) { // Start CANN graph capture
2244-
// ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
2245-
// }
2246-
2247-
// #else
2248-
// bool use_cann_graph = false;
2249-
// bool cann_graph_update_required = false;
2250-
// #endif // USE_CANN_GRAPH
2251-
2252-
// bool graph_evaluated_or_captured = false;
2253-
2254-
// evaluate_and_capture_cann_graph(cann_ctx, cgraph, graph_evaluated_or_captured, use_cann_graph, cann_graph_update_required);
2255-
2256-
// return GGML_STATUS_SUCCESS;
2257-
// }
2258-
2191+
#endif // USE_CANN_GRAPH
2192+
}
2193+
}
22592194

22602195
/**
22612196
* @brief Computes a computational graph using a CANN backend.
@@ -2271,34 +2206,26 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
22712206
*/
22722207
static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
22732208
ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *)backend->context;
2274-
22752209
ggml_cann_set_device(cann_ctx->device);
22762210

22772211
#ifdef USE_CANN_GRAPH
2278-
static const bool disable_cann_graphs_due_to_env = (getenv("GGML_CANN_DISABLE_GRAPHS") != nullptr);
2279-
22802212
bool use_cann_graph = true;
2281-
bool cann_graph_update_required = false;
2213+
bool cann_graph_update_required = true;
2214+
2215+
static const bool disable_cann_graphs_due_to_env = (getenv("GGML_CANN_DISABLE_GRAPHS") != nullptr);
2216+
if (disable_cann_graphs_due_to_env) {
2217+
use_cann_graph = false;
2218+
}
22822219

22832220
// Objects required for CANN Graph
22842221
if (cann_ctx->cann_graph == nullptr) {
22852222
cann_ctx->cann_graph.reset(new ggml_cann_graph());
2286-
cann_graph_update_required = true;
2287-
}
2288-
2289-
// Disable CANN graphs in presence of env var, old NPU, use-case which is changing too rapidly,
2290-
// or previous graph capture failure.
2291-
// Also disable for multi-npu for now. TO DO investigate
2292-
if (disable_cann_graphs_due_to_env
2293-
// || cann_ctx->cann_graph->disable_due_to_too_many_updates
2294-
) {
2295-
use_cann_graph = false;
22962223
}
22972224

22982225
if (use_cann_graph) {
22992226
cann_graph_update_required = is_cann_graph_update_required(cann_ctx, cgraph);
23002227

2301-
// use_cann_graph = check_node_graph_compatibility_and_refresh_copy_ops(cann_ctx, cgraph, cann_graph_update_required);
2228+
use_cann_graph = check_node_graph_compatibility_and_refresh_copy_ops(cann_ctx, cgraph, use_cann_graph);
23022229

23032230
// Disable CANN graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
23042231
if (use_cann_graph && cann_graph_update_required) {
@@ -2307,12 +2234,12 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
23072234
cann_ctx->cann_graph->number_consecutive_updates = 0;
23082235
}
23092236

2310-
if (cann_ctx->cann_graph->number_consecutive_updates >= 4) {
2311-
cann_ctx->cann_graph->disable_due_to_too_many_updates = true;
2312-
#ifndef NDEBUG
2313-
GGML_LOG_DEBUG("%s: disabling CANN graphs due to too many consecutive updates\n", __func__);
2314-
#endif
2315-
}
2237+
// if (cann_ctx->cann_graph->number_consecutive_updates >= 4) {
2238+
// cann_ctx->cann_graph->disable_due_to_too_many_updates = true;
2239+
// #ifndef NDEBUG
2240+
// GGML_LOG_DEBUG("%s: disabling CANN graphs due to too many consecutive updates\n", __func__);
2241+
// #endif
2242+
// }
23162243
}
23172244

23182245
#else

0 commit comments

Comments
 (0)