@@ -43,8 +43,8 @@ CUptiResult (*)(uint32_t enable, CUpti_SubscriberHandle subscriber,
4343 CUpti_CallbackDomain domain, CUpti_CallbackId cbid);
4444
4545#define LOAD_CUPTI_SYM (p, lib, x ) \
46- p-> x = (cupti##x##_fn)ur_loader::LibLoader::getFunctionPtr(lib.get(), \
47- " cupti" #x);
46+ p. x = (cupti##x##_fn)ur_loader::LibLoader::getFunctionPtr(lib.get(), \
47+ " cupti" #x);
4848
4949#else
5050using tracing_event_t = void *;
@@ -55,15 +55,21 @@ using cuptiEnableDomain_fn = void *;
5555using cuptiEnableCallback_fn = void *;
5656#endif // XPTI_ENABLE_INSTRUMENTATION
5757
58+ struct cupti_table_t_ {
59+ cuptiSubscribe_fn Subscribe = nullptr ;
60+ cuptiUnsubscribe_fn Unsubscribe = nullptr ;
61+ cuptiEnableDomain_fn EnableDomain = nullptr ;
62+ cuptiEnableCallback_fn EnableCallback = nullptr ;
63+
64+ bool isInitialized () const ;
65+ };
66+
5867struct cuda_tracing_context_t_ {
5968 tracing_event_t CallEvent = nullptr ;
6069 tracing_event_t DebugEvent = nullptr ;
6170 subscriber_handle_t Subscriber = nullptr ;
6271 ur_loader::LibLoader::Lib Library;
63- cuptiSubscribe_fn Subscribe = nullptr ;
64- cuptiUnsubscribe_fn Unsubscribe = nullptr ;
65- cuptiEnableDomain_fn EnableDomain = nullptr ;
66- cuptiEnableCallback_fn EnableCallback = nullptr ;
72+ cupti_table_t_ Cupti;
6773};
6874
6975#ifdef XPTI_ENABLE_INSTRUMENTATION
@@ -132,6 +138,10 @@ void freeCUDATracingContext(cuda_tracing_context_t_ *Ctx) {
132138#endif // XPTI_ENABLE_INSTRUMENTATION
133139}
134140
141+ bool cupti_table_t_::isInitialized () const {
142+ return Subscribe && Unsubscribe && EnableDomain && EnableCallback;
143+ }
144+
135145bool loadCUDATracingLibrary (cuda_tracing_context_t_ *Ctx) {
136146#if defined(XPTI_ENABLE_INSTRUMENTATION) && defined(CUPTI_LIB_PATH)
137147 if (!Ctx)
@@ -141,16 +151,16 @@ bool loadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx) {
141151 auto Lib{ur_loader::LibLoader::loadAdapterLibrary (CUPTI_LIB_PATH)};
142152 if (!Lib)
143153 return false ;
144- LOAD_CUPTI_SYM (Ctx, Lib, Subscribe)
145- LOAD_CUPTI_SYM (Ctx, Lib, Unsubscribe)
146- LOAD_CUPTI_SYM (Ctx, Lib, EnableDomain)
147- LOAD_CUPTI_SYM (Ctx, Lib, EnableCallback)
148- if (!Ctx->Subscribe || !Ctx->Unsubscribe || !Ctx->EnableDomain ||
149- !Ctx->EnableCallback ) {
150- unloadCUDATracingLibrary (Ctx);
154+ cupti_table_t_ Table;
155+ LOAD_CUPTI_SYM (Table, Lib, Subscribe)
156+ LOAD_CUPTI_SYM (Table, Lib, Unsubscribe)
157+ LOAD_CUPTI_SYM (Table, Lib, EnableDomain)
158+ LOAD_CUPTI_SYM (Table, Lib, EnableCallback)
159+ if (!Table.isInitialized ()) {
151160 return false ;
152161 }
153162 Ctx->Library = std::move (Lib);
163+ Ctx->Cupti = Table;
154164 return true ;
155165#else
156166 (void )Ctx;
@@ -160,14 +170,10 @@ bool loadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx) {
160170
161171void unloadCUDATracingLibrary (cuda_tracing_context_t_ *Ctx) {
162172#ifdef XPTI_ENABLE_INSTRUMENTATION
163- if (!Ctx || !Ctx-> Library )
173+ if (!Ctx)
164174 return ;
165- Ctx->Subscribe = nullptr ;
166- Ctx->Unsubscribe = nullptr ;
167- Ctx->EnableDomain = nullptr ;
168- Ctx->EnableCallback = nullptr ;
169-
170175 Ctx->Library .reset ();
176+ Ctx->Cupti = cupti_table_t_ ();
171177#else
172178 (void )Ctx;
173179#endif // XPTI_ENABLE_INSTRUMENTATION
@@ -207,12 +213,12 @@ void enableCUDATracing(cuda_tracing_context_t_ *Ctx) {
207213 xptiMakeEvent (" CUDA Plugin Debug Layer" , &CUDADebugPayload,
208214 xpti::trace_algorithm_event, xpti_at::active, &Dummy);
209215
210- Ctx->Subscribe (&Ctx->Subscriber , cuptiCallback, Ctx);
211- Ctx->EnableDomain (1 , Ctx->Subscriber , CUPTI_CB_DOMAIN_DRIVER_API);
212- Ctx->EnableCallback (0 , Ctx->Subscriber , CUPTI_CB_DOMAIN_DRIVER_API,
213- CUPTI_DRIVER_TRACE_CBID_cuGetErrorString);
214- Ctx->EnableCallback (0 , Ctx->Subscriber , CUPTI_CB_DOMAIN_DRIVER_API,
215- CUPTI_DRIVER_TRACE_CBID_cuGetErrorName);
216+ Ctx->Cupti . Subscribe (&Ctx->Subscriber , cuptiCallback, Ctx);
217+ Ctx->Cupti . EnableDomain (1 , Ctx->Subscriber , CUPTI_CB_DOMAIN_DRIVER_API);
218+ Ctx->Cupti . EnableCallback (0 , Ctx->Subscriber , CUPTI_CB_DOMAIN_DRIVER_API,
219+ CUPTI_DRIVER_TRACE_CBID_cuGetErrorString);
220+ Ctx->Cupti . EnableCallback (0 , Ctx->Subscriber , CUPTI_CB_DOMAIN_DRIVER_API,
221+ CUPTI_DRIVER_TRACE_CBID_cuGetErrorName);
216222#else
217223 (void )Ctx;
218224#endif
@@ -223,8 +229,8 @@ void disableCUDATracing(cuda_tracing_context_t_ *Ctx) {
223229 if (!Ctx || !xptiTraceEnabled ())
224230 return ;
225231
226- if (Ctx->Subscriber ) {
227- Ctx->Unsubscribe (Ctx->Subscriber );
232+ if (Ctx->Subscriber && Ctx-> Cupti . isInitialized () ) {
233+ Ctx->Cupti . Unsubscribe (Ctx->Subscriber );
228234 Ctx->Subscriber = nullptr ;
229235 }
230236
0 commit comments