Skip to content

Commit 6853c5e

Browse files
authored
Pdlp fix batch cuda graph (#68)
This PR fixes the "invalid operation during graph capture" error we sometimes see when using batch solve. The solution is to use a regular stream instead of a non-blocking stream: this ensures that if any operation (such as a cudaFree from Thrust) is launched on the default stream, it will first wait for all operations on other streams to finish, preventing any cudaMalloc/cudaFree from running while another stream might be performing a CUDA Graph capture.
1 parent 09aa8c9 commit 6853c5e

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

cpp/include/cuopt/linear_programming/utilities/cython_solve.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ struct solver_ret_t {
103103
// Wrapper for solve to expose the API to cython.
104104

105105
std::unique_ptr<solver_ret_t> call_solve(cuopt::mps_parser::data_model_view_t<int, double>*,
106-
linear_programming::solver_settings_t<int, double>*);
106+
linear_programming::solver_settings_t<int, double>*,
107+
unsigned int flags = cudaStreamNonBlocking);
107108

108109
std::pair<std::vector<std::unique_ptr<solver_ret_t>>, double> call_batch_solve(
109110
std::vector<cuopt::mps_parser::data_model_view_t<int, double>*>,

cpp/src/linear_programming/utilities/cython_solve.cu

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,12 +208,13 @@ mip_ret_t call_solve_mip(
208208

209209
std::unique_ptr<solver_ret_t> call_solve(
210210
cuopt::mps_parser::data_model_view_t<int, double>* data_model,
211-
cuopt::linear_programming::solver_settings_t<int, double>* solver_settings)
211+
cuopt::linear_programming::solver_settings_t<int, double>* solver_settings,
212+
unsigned int flags)
212213
{
213214
raft::common::nvtx::range fun_scope("Call Solve");
214215

215216
cudaStream_t stream;
216-
RAFT_CUDA_TRY(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
217+
RAFT_CUDA_TRY(cudaStreamCreateWithFlags(&stream, flags));
217218
const raft::handle_t handle_{stream};
218219

219220
auto op_problem = data_model_to_optimization_problem(data_model, solver_settings, &handle_);
@@ -283,9 +284,11 @@ std::pair<std::vector<std::unique_ptr<solver_ret_t>>, double> call_batch_solve(
283284
solver_settings->set_parameter(CUOPT_METHOD, CUOPT_METHOD_PDLP);
284285
}
285286

287+
// Use a default stream instead of a non-blocking to avoid invalid operations while some CUDA
288+
// Graph might be capturing in another stream
286289
#pragma omp parallel for num_threads(max_thread)
287290
for (std::size_t i = 0; i < size; ++i)
288-
list[i] = std::move(call_solve(data_models[i], solver_settings));
291+
list[i] = std::move(call_solve(data_models[i], solver_settings, cudaStreamDefault));
289292

290293
auto end = std::chrono::high_resolution_clock::now();
291294
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start_solver);

0 commit comments

Comments
 (0)