@@ -2480,6 +2480,7 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
24802480 for (int i = 0 ; i < GGML_MAX_SRC; i++) {
24812481 graph_node_properties->src_address [i] = node->src [i] ? node->src [i]->data : nullptr ;
24822482 }
2483+ memcpy (graph_node_properties->op_params , node->op_params , GGML_MAX_OP_PARAMS);
24832484}
24842485
24852486static bool ggml_graph_node_has_matching_properties (ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
@@ -2511,6 +2512,12 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
25112512 return false ;
25122513 }
25132514 }
2515+
2516+ if (node->op == GGML_OP_SCALE &&
2517+ memcmp (graph_node_properties->op_params , node->op_params , GGML_MAX_OP_PARAMS) != 0 ) {
2518+ return false ;
2519+ }
2520+
25142521 return true ;
25152522}
25162523
@@ -2721,7 +2728,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
27212728 // First call with null argument gets number of nodes in graph
27222729 CUDA_CHECK (cudaGraphGetNodes (cuda_ctx->cuda_graph ->graph , nullptr , &cuda_ctx->cuda_graph ->num_nodes ));
27232730 // Subsequent call with non-null argument gets nodes
2731+ cuda_ctx->cuda_graph ->nodes .clear ();
27242732 cuda_ctx->cuda_graph ->nodes .resize (cuda_ctx->cuda_graph ->num_nodes );
2733+ cuda_ctx->cuda_graph ->params .clear ();
27252734 cuda_ctx->cuda_graph ->params .resize (cuda_ctx->cuda_graph ->num_nodes );
27262735 if (cuda_ctx->cuda_graph ->num_nodes > 0 ) {
27272736 CUDA_CHECK (cudaGraphGetNodes (cuda_ctx->cuda_graph ->graph , cuda_ctx->cuda_graph ->nodes .data (), &cuda_ctx->cuda_graph ->num_nodes ));
0 commit comments