Skip to content

Commit a05f2b8

Browse files
committed
Disable CUDA Graph for MUSA
1 parent db400a6 commit a05f2b8

File tree

4 files changed

+8
-11
lines changed

4 files changed

+8
-11
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ struct ggml_tensor_extra_gpu {
678678
};
679679

680680

681-
#if (defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS)
681+
#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
682682
#define USE_CUDA_GRAPH
683683
#endif
684684

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2610,14 +2610,14 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
26102610

26112611
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
26122612

2613-
#if (CUDART_VERSION < 12000 || defined(__HIP_PLATFORM_AMD__))
2613+
#if CUDART_VERSION >= 12000
2614+
cudaGraphExecUpdateResultInfo result_info;
2615+
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
2616+
#else
26142617
cudaGraphNode_t errorNode;
26152618
cudaGraphExecUpdateResult result_info;
26162619
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
2617-
#else
2618-
cudaGraphExecUpdateResultInfo result_info;
2619-
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
2620-
#endif // (CUDART_VERSION < 12000 || defined(__HIP_PLATFORM_AMD__))
2620+
#endif // CUDART_VERSION >= 12000
26212621

26222622
if (stat == cudaErrorGraphExecUpdateFailure) {
26232623
#ifndef NDEBUG

ggml/src/ggml-cuda/vendors/musa.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@
119119
#define cudaGraphExecDestroy musaGraphExecDestroy
120120
#define cudaGraphExec_t musaGraphExec_t
121121
#define cudaGraphExecUpdate musaGraphExecUpdate
122-
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
122+
#define cudaGraphExecUpdateResult musaGraphExecUpdateResult
123123
#define cudaGraphGetNodes musaGraphGetNodes
124124
#define cudaGraphInstantiate musaGraphInstantiate
125125
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
@@ -132,6 +132,7 @@
132132
#define cudaGraph_t musaGraph_t
133133
#define cudaKernelNodeParams musaKernelNodeParams
134134
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
135+
#define cudaStreamBeginCapture musaStreamBeginCapture
135136
#define cudaStreamEndCapture musaStreamEndCapture
136137

137138
typedef mt_bfloat16 nv_bfloat16;

ggml/src/ggml-musa/CMakeLists.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,6 @@ if (MUSAToolkit_FOUND)
6767
add_compile_definitions(GGML_USE_MUSA)
6868
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
6969

70-
if (GGML_CUDA_GRAPHS)
71-
add_compile_definitions(GGML_CUDA_USE_GRAPHS)
72-
endif()
73-
7470
if (GGML_CUDA_FORCE_MMQ)
7571
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
7672
endif()

0 commit comments

Comments
 (0)