@@ -332,21 +332,30 @@ bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
332332 for (const auto & node : nodes) {
333333 cudaGraphNodeType type;
334334 CHECK_CUDA_ERROR (cudaGraphNodeGetType (node, &type));
335- if (type != cudaGraphNodeTypeKernel) {
336- return false ;
337- }
338- cudaLaunchAttributeValue cluster_dim;
339- CHECK_CUDA_ERROR (cudaGraphKernelNodeGetAttribute (
340- node, cudaLaunchAttributeClusterDimension, &cluster_dim));
341- // Only dim.x can be greater than 1
342- if (cluster_dim.clusterDim .y > 1 || cluster_dim.clusterDim .z > 1 ) {
343- return false ;
344- }
345- // Only one child node allowed when subgraph uses clusters
346- if (cluster_dim.clusterDim .x > 0 && num_nodes > 1 ) {
335+ if (type == cudaGraphNodeTypeGraph) {
336+ // Try to be updatable for a structure like graph -> graph -> kernel
337+ if (num_nodes > 1 ) {
338+ return false ;
339+ }
340+ cudaGraph_t child;
341+ CHECK_CUDA_ERROR (cudaGraphChildGraphNodeGetGraph (node, &child));
342+ return is_graph_updatable (child, cluster_dim_x);
343+ } else if (type != cudaGraphNodeTypeKernel) {
347344 return false ;
345+ } else {
346+ cudaLaunchAttributeValue cluster_dim;
347+ CHECK_CUDA_ERROR (cudaGraphKernelNodeGetAttribute (
348+ node, cudaLaunchAttributeClusterDimension, &cluster_dim));
349+ // Only dim.x can be greater than 1
350+ if (cluster_dim.clusterDim .y > 1 || cluster_dim.clusterDim .z > 1 ) {
351+ return false ;
352+ }
353+ // Only one child node allowed when subgraph uses clusters
354+ if (cluster_dim.clusterDim .x > 0 && num_nodes > 1 ) {
355+ return false ;
356+ }
357+ cluster_dim_x = cluster_dim.clusterDim .x ;
348358 }
349- cluster_dim_x = cluster_dim.clusterDim .x ;
350359 }
351360 return true ;
352361}
@@ -362,7 +371,7 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
362371 }
363372 cudaGraphNode_t node;
364373 int cluster_dim_x = 0 ;
365- is_graph_updatable_ = is_graph_updatable (child, cluster_dim_x);
374+ is_graph_updatable_ & = is_graph_updatable (child, cluster_dim_x);
366375 CHECK_CUDA_ERROR (cudaGraphAddChildGraphNode (&node, graph_, NULL , 0 , child));
367376 insert_graph_dependencies (
368377 GraphNode{node, " G" + std::to_string (cluster_dim_x)});
0 commit comments