Skip to content

Commit 001c1d4

Browse files
committed
[CANN]Support Acl Graph
Signed-off-by: noemotiovon <[email protected]>
1 parent 79c137f commit 001c1d4

File tree

2 files changed

+413
-14
lines changed

2 files changed

+413
-14
lines changed

ggml/src/ggml-cann/common.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,58 @@ class cann_task_queue {
333333
int32_t device_;
334334
};
335335

336+
// TODO: 删除 //
337+
// #if defined(GGML_CANN_USE_GRAPHS)
338+
#define USE_CANN_GRAPH
339+
// #endif
340+
341+
struct ggml_graph_node_properties {
342+
void * node_address;
343+
ggml_op node_op;
344+
int64_t ne[GGML_MAX_DIMS];
345+
size_t nb[GGML_MAX_DIMS];
346+
void * src_address[GGML_MAX_SRC];
347+
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
348+
};
349+
350+
struct ggml_cann_graph {
351+
#ifdef USE_CANN_GRAPH
352+
~ggml_cann_graph() {
353+
// if (instance != nullptr) {
354+
// aclmdlRIDestroy(instance);
355+
// }
356+
if (graph != nullptr) {
357+
aclmdlRIDestroy(graph);
358+
}
359+
}
360+
361+
aclmdlRI graph = nullptr;
362+
// aclmdlRI instance = nullptr;
363+
364+
size_t num_nodes = 0;
365+
366+
// std::vector<aclGraphNode*> nodes;
367+
std::vector<aclopAttr*> op_attrs;
368+
std::vector<aclTensor*> input_tensors;
369+
std::vector<aclTensor*> output_tensors;
370+
371+
// bool disable_due_to_npu_arch = false;
372+
bool disable_due_to_too_many_updates = false;
373+
// bool disable_due_to_failed_graph_capture = false;
374+
int number_consecutive_updates = 0;
375+
376+
std::vector<ggml_graph_node_properties> ggml_graph_properties;
377+
378+
// TODO: user cpy indirection
379+
// bool use_cpy_indirection = false;
380+
// std::vector<char *> cpy_dest_ptrs;
381+
// char ** dest_ptrs_d = nullptr;
382+
// int dest_ptrs_size = 0;
383+
384+
int graph_cpynode_index = -1;
385+
#endif // USE_CANN_GRAPH
386+
};
387+
336388
/**
337389
* @brief Context for managing CANN backend operations.
338390
*/
@@ -341,6 +393,7 @@ struct ggml_backend_cann_context {
341393
std::string name; /**< Name of the device. */
342394
std::string description; /**< Description of the device. */
343395
aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
396+
std::unique_ptr<ggml_cann_graph> cann_graph;
344397
cann_task_queue task_queue;
345398
bool async_mode;
346399

0 commit comments

Comments
 (0)