Skip to content

Commit 193cdcd

Browse files
authored
Fix graph updating (#2857)
1 parent d8ceae7 commit 193cdcd

File tree

1 file changed

+23
-14
lines changed

1 file changed

+23
-14
lines changed

mlx/backend/cuda/device.cpp

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -332,21 +332,30 @@ bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
332332
for (const auto& node : nodes) {
333333
cudaGraphNodeType type;
334334
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
335-
if (type != cudaGraphNodeTypeKernel) {
336-
return false;
337-
}
338-
cudaLaunchAttributeValue cluster_dim;
339-
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
340-
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
341-
// Only dim.x can be greater than 1
342-
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
343-
return false;
344-
}
345-
// Only one child node allowed when subgraph uses clusters
346-
if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
335+
if (type == cudaGraphNodeTypeGraph) {
336+
// Try to be updatable for a structure like graph -> graph -> kernel
337+
if (num_nodes > 1) {
338+
return false;
339+
}
340+
cudaGraph_t child;
341+
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
342+
return is_graph_updatable(child, cluster_dim_x);
343+
} else if (type != cudaGraphNodeTypeKernel) {
347344
return false;
345+
} else {
346+
cudaLaunchAttributeValue cluster_dim;
347+
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
348+
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
349+
// Only dim.x can be greater than 1
350+
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
351+
return false;
352+
}
353+
// Only one child node allowed when subgraph uses clusters
354+
if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
355+
return false;
356+
}
357+
cluster_dim_x = cluster_dim.clusterDim.x;
348358
}
349-
cluster_dim_x = cluster_dim.clusterDim.x;
350359
}
351360
return true;
352361
}
@@ -362,7 +371,7 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
362371
}
363372
cudaGraphNode_t node;
364373
int cluster_dim_x = 0;
365-
is_graph_updatable_ = is_graph_updatable(child, cluster_dim_x);
374+
is_graph_updatable_ &= is_graph_updatable(child, cluster_dim_x);
366375
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
367376
insert_graph_dependencies(
368377
GraphNode{node, "G" + std::to_string(cluster_dim_x)});

0 commit comments

Comments
 (0)