@@ -2133,7 +2133,15 @@ static void add_lru_matched_graph_node_properties(
21332133 std::copy_n (node->nb , GGML_MAX_DIMS, prop.nb );
21342134
21352135 for (int src = 0 ; src < GGML_MAX_SRC; ++src) {
2136- prop.src_address [src] = node->src [src] ? node->src [src]->data : nullptr ;
2136+ if (node->src [src]) {
2137+ prop.src_address [src] = node->src [src]->data ;
2138+ std::copy_n (node->src [src]->ne , GGML_MAX_DIMS, prop.src_ne [src]);
2139+ std::copy_n (node->src [src]->nb , GGML_MAX_DIMS, prop.src_nb [src]);
2140+ } else {
2141+ prop.src_address [src] = nullptr ;
2142+ std::fill_n (prop.src_ne [src], GGML_MAX_DIMS, 0 );
2143+ std::fill_n (prop.src_nb [src], GGML_MAX_DIMS, 0 );
2144+ }
21372145 }
21382146
21392147 memcpy (prop.op_params , node->op_params , GGML_MAX_OP_PARAMS);
@@ -2153,14 +2161,18 @@ static void add_lru_matched_graph_node_properties(
21532161 * @param graph_node_properties The stored properties of a CANN graph node.
21542162 * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
21552163 */
2156- static bool ggml_graph_node_has_matching_properties (ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
2164+ static bool ggml_graph_node_has_matching_properties (
2165+ ggml_tensor * node,
2166+ ggml_graph_node_properties * graph_node_properties) {
21572167 if (node->data != graph_node_properties->node_address &&
2158- node->op != GGML_OP_VIEW) {
2168+ node->op != GGML_OP_VIEW) {
21592169 return false ;
21602170 }
2171+
21612172 if (node->op != graph_node_properties->node_op ) {
21622173 return false ;
21632174 }
2175+
21642176 for (int i = 0 ; i < GGML_MAX_DIMS; i++) {
21652177 if (node->ne [i] != graph_node_properties->ne [i]) {
21662178 return false ;
@@ -2169,17 +2181,31 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
21692181 return false ;
21702182 }
21712183 }
2184+
21722185 for (int i = 0 ; i < GGML_MAX_SRC; i++) {
2173- if (node->src [i] &&
2174- node->src [i]->data != graph_node_properties->src_address [i] &&
2175- node->op != GGML_OP_VIEW
2176- ) {
2177- return false ;
2186+ if (node->src [i]) {
2187+ if (node->src [i]->data != graph_node_properties->src_address [i] &&
2188+ node->op != GGML_OP_VIEW) {
2189+ return false ;
2190+ }
2191+
2192+ for (int d = 0 ; d < GGML_MAX_DIMS; d++) {
2193+ if (node->src [i]->ne [d] != graph_node_properties->src_ne [i][d]) {
2194+ return false ;
2195+ }
2196+ if (node->src [i]->nb [d] != graph_node_properties->src_nb [i][d]) {
2197+ return false ;
2198+ }
2199+ }
2200+ } else {
2201+ if (graph_node_properties->src_address [i] != nullptr ) {
2202+ return false ;
2203+ }
21782204 }
21792205 }
2180- if (node-> op == GGML_OP_SCALE &&
2181- memcmp (graph_node_properties-> op_params , node->op_params , GGML_MAX_OP_PARAMS) != 0 ) {
2182- return false ;
2206+
2207+ if (node-> op == GGML_OP_SCALE || node-> op == GGML_OP_UNARY || node->op == GGML_OP_GLU ) {
2208+ return memcmp (graph_node_properties-> op_params , node-> op_params , GGML_MAX_OP_PARAMS) == 0 ;
21832209 }
21842210 return true ;
21852211}
0 commit comments