2323#include " dual_simplex/tic_toc.hpp"
2424
2525#include < cuda_runtime.h>
26+ #include < utilities/driver_helpers.cuh>
2627
2728#include < raft/common/nvtx.hpp>
2829
29- #include " cuda.h"
3030#include " cudss.h"
3131
3232namespace cuopt ::linear_programming::dual_simplex {
@@ -157,37 +157,46 @@ class sparse_cholesky_cudss_t : public sparse_cholesky_base_t<i_t, f_t> {
157157 cudssGetProperty (PATCH_LEVEL, &patch);
158158 settings.log .printf (" cuDSS Version : %d.%d.%d\n " , major, minor, patch);
159159
160- CU_CHECK (cuDriverGetVersion (&driver_version));
161- settings_.log .printf (" CUDA Driver Version : %d\n " , driver_version);
162-
163160 cuda_error = cudaSuccess;
164161 status = CUDSS_STATUS_SUCCESS;
165162
166- if (settings_. concurrent_halt != nullptr && driver_version >= 13000 ) {
167- # if defined(SPLIT_SM_FOR_BARRIER) && CUDART_VERSION >= 13000
163+ if (CUDART_VERSION >= 13000 && settings_. concurrent_halt != nullptr ) {
164+ cuGetErrorString_func = cuopt::detail::get_driver_entry_point ( " cuGetErrorString " );
168165 // 1. Set up the GPU resources
169166 CUdevResource initial_device_GPU_resources = {};
170- CU_CHECK (cuDeviceGetDevResource (
171- handle_ptr_->get_device (), &initial_device_GPU_resources, CU_DEV_RESOURCE_TYPE_SM));
167+ auto cuDeviceGetDevResource_func =
168+ cuopt::detail::get_driver_entry_point (" cuDeviceGetDevResource" );
169+ CU_CHECK (reinterpret_cast <decltype (::cuDeviceGetDevResource)*>(cuDeviceGetDevResource_func)(
170+ handle_ptr_->get_device (), &initial_device_GPU_resources, CU_DEV_RESOURCE_TYPE_SM),
171+ reinterpret_cast <decltype (::cuGetErrorString)*>(cuGetErrorString_func));
172+
172173#ifdef DEBUG
173- std::cout << " Initial GPU resources retrieved via cuDeviceGetDevResource() have type "
174- << initial_device_GPU_resources.type << " and SM count "
175- << initial_device_GPU_resources.sm .smCount << std::endl;
174+ settings.log .printf (
175+ " Initial GPU resources retrieved via "
176+ " cuDeviceGetDevResource() have type "
177+ " %d and SM count %d\n " ,
178+ initial_device_GPU_resources.type ,
179+ initial_device_GPU_resources.sm .smCount );
176180#endif
177181
178182 // 2. Partition the GPU resources
179183 auto total_SMs = initial_device_GPU_resources.sm .smCount ;
180184 auto barrier_sms = raft::alignTo (static_cast <i_t >(total_SMs * 0 .75f ), 8 );
181- CUdevResource input;
182185 CUdevResource resource;
186+ auto cuDevSmResourceSplitByCount_func =
187+ cuopt::detail::get_driver_entry_point (" cuDevSmResourceSplitByCount" );
183188 auto n_groups = 1u ;
184189 auto use_flags = CU_DEV_SM_RESOURCE_SPLIT_IGNORE_SM_COSCHEDULING; // or 0
185- CU_CHECK (cuDevSmResourceSplitByCount (
186- &resource, &n_groups, &initial_device_GPU_resources, nullptr , use_flags, barrier_sms));
190+ CU_CHECK (
191+ reinterpret_cast <decltype (::cuDevSmResourceSplitByCount)*>(
192+ cuDevSmResourceSplitByCount_func)(
193+ &resource, &n_groups, &initial_device_GPU_resources, nullptr , use_flags, barrier_sms),
194+ reinterpret_cast <decltype (::cuGetErrorString)*>(cuGetErrorString_func));
187195#ifdef DEBUG
188- printf (
189- " Resources were split into %d resource groups (had requested %d) with %d SMs each (had "
190- " requested %d)\n " ,
196+ settings.log .printf (
197+ " Resources were split into %d resource groups (had "
198+ " requested %d) with %d SMs each (had "
199+ " requested % d)\n " ,
191200 n_groups,
192201 n_groups,
193202 resource.sm .smCount ,
@@ -196,34 +205,42 @@ class sparse_cholesky_cudss_t : public sparse_cholesky_base_t<i_t, f_t> {
196205 // 3. Create the resource descriptor
197206 auto constexpr const n_resource_desc = 1 ;
198207 CUdevResourceDesc resource_desc;
199- CU_CHECK (cuDevResourceGenerateDesc (&resource_desc, &resource, n_resource_desc));
208+ auto cuDevResourceGenerateDesc_func =
209+ cuopt::detail::get_driver_entry_point (" cuDevResourceGenerateDesc" );
210+ CU_CHECK (reinterpret_cast <decltype (::cuDevResourceGenerateDesc)*>(
211+ cuDevResourceGenerateDesc_func)(&resource_desc, &resource, n_resource_desc),
212+ reinterpret_cast <decltype (::cuGetErrorString)*>(cuGetErrorString_func));
200213#ifdef DEBUG
201- printf (
202- " For the resource descriptor of barrier green context we will combine %d resources of "
203- " %d "
204- " SMs each\n " ,
214+ settings. log . printf (
215+ " For the resource descriptor of barrier green context "
216+ " we will combine %d resources of "
217+ " %d SMs each\n " ,
205218 n_resource_desc,
206219 resource.sm .smCount );
207220#endif
208221
209222 // Only perform this if CUDA version is more than 13
210- // (all resource splitting and descriptor creation already above)
211- // No additional code needed here as the logic is already guarded above.
212- // 4. Create the green context and stream for that green context
213- // CUstream barrier_green_ctx_stream;
223+ // (all resource splitting and descriptor creation already
224+ // above) No additional code needed here as the logic is
225+ // already guarded above.
226+ // 4. Create the green context and stream for that green
227+ // context CUstream barrier_green_ctx_stream;
214228 i_t stream_priority;
215229 cudaStream_t cuda_stream = handle_ptr_->get_stream ();
216230 cudaError_t priority_result = cudaStreamGetPriority (cuda_stream, &stream_priority);
217231 RAFT_CUDA_TRY (priority_result);
218- CU_CHECK (cuGreenCtxCreate (
219- &barrier_green_ctx, resource_desc, handle_ptr_->get_device (), CU_GREEN_CTX_DEFAULT_STREAM));
220- CU_CHECK (cuGreenCtxStreamCreate (
221- &stream, barrier_green_ctx, CU_STREAM_NON_BLOCKING, stream_priority));
222- #endif
223- } else {
224- // Convert runtime API stream to driver API stream for consistency
225- cudaStream_t cuda_stream = handle_ptr_->get_stream ();
226- stream = reinterpret_cast <CUstream>(cuda_stream);
232+ auto cuGreenCtxCreate_func = cuopt::detail::get_driver_entry_point (" cuGreenCtxCreate" );
233+ CU_CHECK (reinterpret_cast <decltype (::cuGreenCtxCreate)*>(cuGreenCtxCreate_func)(
234+ &barrier_green_ctx,
235+ resource_desc,
236+ handle_ptr_->get_device (),
237+ CU_GREEN_CTX_DEFAULT_STREAM),
238+ reinterpret_cast <decltype (::cuGetErrorString)*>(cuGetErrorString_func));
239+ auto cuGreenCtxStreamCreate_func =
240+ cuopt::detail::get_driver_entry_point (" cuGreenCtxStreamCreate" );
241+ CU_CHECK (reinterpret_cast <decltype (::cuGreenCtxStreamCreate)*>(cuGreenCtxStreamCreate_func)(
242+ &stream, barrier_green_ctx, CU_STREAM_NON_BLOCKING, stream_priority),
243+ reinterpret_cast <decltype (::cuGetErrorString)*>(cuGetErrorString_func));
227244 }
228245
229246 CUDSS_CALL_AND_CHECK_EXIT (cudssCreate (&handle), status, " cudssCreate" );
@@ -336,12 +353,15 @@ class sparse_cholesky_cudss_t : public sparse_cholesky_base_t<i_t, f_t> {
336353 CUDSS_CALL_AND_CHECK_EXIT (cudssConfigDestroy (solverConfig), status, " cudssConfigDestroy" );
337354 CUDSS_CALL_AND_CHECK_EXIT (cudssDestroy (handle), status, " cudssDestroy" );
338355 CUDA_CALL_AND_CHECK_EXIT (cudaStreamSynchronize (stream), " cudaStreamSynchronize" );
339- #ifdef SPLIT_SM_FOR_BARRIER
340- if (settings_.concurrent_halt != nullptr && driver_version >= 13000 ) {
341- CU_CHECK (cuStreamDestroy (stream));
342356#if CUDART_VERSION >= 13000
343- CU_CHECK (cuGreenCtxDestroy (barrier_green_ctx));
344- #endif
357+ if (settings_.concurrent_halt != nullptr ) {
358+ auto cuStreamDestroy_func = cuopt::detail::get_driver_entry_point (" cuStreamDestroy" );
359+ CU_CHECK (reinterpret_cast <decltype (::cuStreamDestroy)*>(cuStreamDestroy_func)(stream),
360+ reinterpret_cast <decltype (::cuGetErrorString)*>(cuGetErrorString_func));
361+ auto cuGreenCtxDestroy_func = cuopt::detail::get_driver_entry_point (" cuGreenCtxDestroy" );
362+ CU_CHECK (
363+ reinterpret_cast <decltype (::cuGreenCtxDestroy)*>(cuGreenCtxDestroy_func)(barrier_green_ctx),
364+ reinterpret_cast <decltype (::cuGetErrorString)*>(cuGetErrorString_func));
345365 handle_ptr_->get_stream ().synchronize ();
346366 }
347367#endif
@@ -473,7 +493,7 @@ class sparse_cholesky_cudss_t : public sparse_cholesky_base_t<i_t, f_t> {
473493
474494 auto d_nnz = Arow.row_start .element (Arow.m , Arow.row_start .stream ());
475495 if (nnz != d_nnz) {
476- printf (" Error: nnz %d != A_in.col_start[A_in.n] %d\n " , nnz, d_nnz);
496+ settings_. log . printf (" Error: nnz %d != A_in.col_start[A_in.n] %d\n " , nnz, d_nnz);
477497 exit (1 );
478498 }
479499
@@ -796,11 +816,11 @@ class sparse_cholesky_cudss_t : public sparse_cholesky_base_t<i_t, f_t> {
796816 f_t * csr_values_d;
797817 f_t * x_values_d;
798818 f_t * b_values_d;
799- i_t driver_version;
800819
801820 const simplex_solver_settings_t <i_t , f_t >& settings_;
802821 CUgreenCtx barrier_green_ctx;
803822 CUstream stream;
823+ void * cuGetErrorString_func;
804824};
805825
806826} // namespace cuopt::linear_programming::dual_simplex
0 commit comments