Skip to content

Commit 7f0520c

Browse files
authored
bug fix to multi-cudagraph (#19856)
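### Description <!-- Describe your changes. --> The regular run count before graph capture needs to be graph_id aware, but it was tracked in a single counter (`regular_run_count_before_graph_capture_`) shared by all graphs. Fix the bug by replacing that counter with a map (`graph_id_to_run_count_`) that stores and retrieves the run count for each graph_id. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->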
1 parent 319159b commit 7f0520c

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

onnxruntime/core/providers/cuda/cuda_execution_provider.cc

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,13 @@ CUDAExecutionProvider::PerThreadContext::~PerThreadContext() {
194194

195195
bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowed(
196196
CudaGraphAnnotation_t cuda_graph_annotation_id) const {
197-
return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_ &&
198-
IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id);
197+
if (!IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id)) {
198+
return false;
199+
}
200+
if (graph_id_to_run_count_.find(cuda_graph_annotation_id) == graph_id_to_run_count_.end()) {
201+
return false;
202+
}
203+
return graph_id_to_run_count_.at(cuda_graph_annotation_id) >= min_num_runs_before_cuda_graph_capture_;
199204
}
200205

201206
bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowedOnRun(
@@ -234,8 +239,13 @@ Status CUDAExecutionProvider::PerThreadContext::ReplayGraph(CudaGraphAnnotation_
234239
return cuda_graph_.Replay(graph_annotation_id);
235240
}
236241

237-
void CUDAExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() {
238-
++regular_run_count_before_graph_capture_;
242+
void CUDAExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture(
243+
CudaGraphAnnotation_t cuda_graph_annotation_id) {
244+
if (graph_id_to_run_count_.find(cuda_graph_annotation_id) == graph_id_to_run_count_.end()) {
245+
graph_id_to_run_count_[cuda_graph_annotation_id] = 1;
246+
return;
247+
}
248+
graph_id_to_run_count_[cuda_graph_annotation_id]++;
239249
}
240250

241251
void OverrideTunableOpInfoByEnv(CUDAExecutionProviderInfo& info) {
@@ -428,7 +438,7 @@ Status CUDAExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunO
428438
// so run the captured graph here to actually execute the work.
429439
ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id));
430440
} else {
431-
GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture();
441+
GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture(cuda_graph_annotation_id);
432442
}
433443
}
434444

onnxruntime/core/providers/cuda/cuda_execution_provider.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
175175
bool IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const;
176176
CudaGraphAnnotation_t GetCudaGraphAnnotationId(const onnxruntime::RunOptions& run_options) const;
177177
Status ReplayGraph(CudaGraphAnnotation_t cuda_graph_annotation_id);
178-
void IncrementRegularRunCountBeforeGraphCapture();
178+
void IncrementRegularRunCountBeforeGraphCapture(CudaGraphAnnotation_t cuda_graph_annotation_id);
179179

180180
private:
181181
cublasHandle_t cublas_handle_ = nullptr;
@@ -194,7 +194,8 @@ class CUDAExecutionProvider : public IExecutionProvider {
194194
// Cuda graph with multi threads will be supported in the future, so cuda_graph_
195195
// is put under PerThreadContext.
196196
CUDAGraph cuda_graph_;
197-
int regular_run_count_before_graph_capture_ = 0;
197+
// Map of graph id to regular_run_count_before_graph_capture
198+
std::unordered_map<CudaGraphAnnotation_t, int> graph_id_to_run_count_;
198199

199200
// There is chance that the second regular run allocates GPU memory for causes like:
200201
// (1) memory pattern is enabled. (2) arena allocation for stream.

0 commit comments

Comments
 (0)