1616#include < memory>
1717
1818namespace ur ::cuda {
19- std::shared_ptr<ur_adapter_handle_t_> adapter;
19+ ur_adapter_handle_t adapter;
2020} // namespace ur::cuda
2121
2222class ur_legacy_sink : public logger ::Sink {
@@ -47,22 +47,25 @@ ur_adapter_handle_t_::ur_adapter_handle_t_()
4747 std::getenv (" UR_SUPPRESS_ERROR_MESSAGE" ) != nullptr )) {
4848 logger.setLegacySink (std::make_unique<ur_legacy_sink>());
4949 }
50+
51+ TracingCtx = createCUDATracingContext ();
52+ enableCUDATracing (TracingCtx);
53+ }
54+
55+ ur_adapter_handle_t_::~ur_adapter_handle_t_ () {
56+ disableCUDATracing (TracingCtx);
57+ freeCUDATracingContext (TracingCtx);
5058}
5159
5260UR_APIEXPORT ur_result_t UR_APICALL
5361urAdapterGet (uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
5462 uint32_t *pNumAdapters) {
5563 if (NumEntries > 0 && phAdapters) {
5664 static std::once_flag InitFlag;
57- std::call_once (InitFlag, [=]() {
58- ur::cuda::adapter = std::make_shared<ur_adapter_handle_t_>();
59- });
60-
61- std::lock_guard<std::mutex> Lock{ur::cuda::adapter->Mutex };
62- ur::cuda::adapter->TracingCtx = createCUDATracingContext ();
63- enableCUDATracing (ur::cuda::adapter->TracingCtx );
65+ std::call_once (InitFlag,
66+ [=]() { ur::cuda::adapter = new ur_adapter_handle_t_; });
6467
65- *phAdapters = ur::cuda::adapter. get () ;
68+ *phAdapters = ur::cuda::adapter;
6669 }
6770
6871 if (pNumAdapters) {
@@ -79,11 +82,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
7982}
8083
8184UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease (ur_adapter_handle_t ) {
82- std::lock_guard<std::mutex> Lock{ur::cuda::adapter->Mutex };
8385 if (--ur::cuda::adapter->RefCount == 0 ) {
84- disableCUDATracing (ur::cuda::adapter->TracingCtx );
85- freeCUDATracingContext (ur::cuda::adapter->TracingCtx );
86- ur::cuda::adapter.reset ();
86+ delete ur::cuda::adapter;
8787 }
8888 return UR_RESULT_SUCCESS;
8989}
0 commit comments