1818#include < cupti.h>
1919#endif // XPTI_ENABLE_INSTRUMENTATION
2020
21+ #include " tracing.hpp"
22+ #include " ur_lib_loader.hpp"
2123#include < exception>
2224#include < iostream>
2325
26+ #ifdef XPTI_ENABLE_INSTRUMENTATION
27+ using tracing_event_t = xpti_td *;
28+ using subscriber_handle_t = CUpti_SubscriberHandle;
29+
30+ using cuptiSubscribe_fn = CUPTIAPI
31+ CUptiResult (*)(CUpti_SubscriberHandle *subscriber, CUpti_CallbackFunc callback,
32+ void *userdata);
33+
34+ using cuptiUnsubscribe_fn = CUPTIAPI
35+ CUptiResult (*)(CUpti_SubscriberHandle subscriber);
36+
37+ using cuptiEnableDomain_fn = CUPTIAPI
38+ CUptiResult (*)(uint32_t enable, CUpti_SubscriberHandle subscriber,
39+ CUpti_CallbackDomain domain);
40+
41+ using cuptiEnableCallback_fn = CUPTIAPI
42+ CUptiResult (*)(uint32_t enable, CUpti_SubscriberHandle subscriber,
43+ CUpti_CallbackDomain domain, CUpti_CallbackId cbid);
44+
45+ #define LOAD_CUPTI_SYM (p, lib, x ) \
46+ p.x = (cupti##x##_fn)ur_loader::LibLoader::getFunctionPtr(lib.get(), \
47+ " cupti" #x);
48+
49+ #else
50+ using tracing_event_t = void *;
51+ using subscriber_handle_t = void *;
52+ using cuptiSubscribe_fn = void *;
53+ using cuptiUnsubscribe_fn = void *;
54+ using cuptiEnableDomain_fn = void *;
55+ using cuptiEnableCallback_fn = void *;
56+ #endif // XPTI_ENABLE_INSTRUMENTATION
57+
58+ struct cupti_table_t_ {
59+ cuptiSubscribe_fn Subscribe = nullptr ;
60+ cuptiUnsubscribe_fn Unsubscribe = nullptr ;
61+ cuptiEnableDomain_fn EnableDomain = nullptr ;
62+ cuptiEnableCallback_fn EnableCallback = nullptr ;
63+
64+ bool isInitialized () const ;
65+ };
66+
67+ struct cuda_tracing_context_t_ {
68+ tracing_event_t CallEvent = nullptr ;
69+ tracing_event_t DebugEvent = nullptr ;
70+ subscriber_handle_t Subscriber = nullptr ;
71+ ur_loader::LibLoader::Lib Library;
72+ cupti_table_t_ Cupti;
73+ };
74+
2475#ifdef XPTI_ENABLE_INSTRUMENTATION
2576constexpr auto CUDA_CALL_STREAM_NAME = " sycl.experimental.cuda.call" ;
2677constexpr auto CUDA_DEBUG_STREAM_NAME = " sycl.experimental.cuda.debug" ;
2778
2879thread_local uint64_t CallCorrelationID = 0 ;
2980thread_local uint64_t DebugCorrelationID = 0 ;
3081
31- static xpti_td *GCallEvent = nullptr ;
32- static xpti_td *GDebugEvent = nullptr ;
33-
3482constexpr auto GVerStr = " 0.1" ;
3583constexpr int GMajVer = 0 ;
3684constexpr int GMinVer = 1 ;
3785
38- static void cuptiCallback (void *, CUpti_CallbackDomain, CUpti_CallbackId CBID ,
39- const void *CBData) {
86+ static void cuptiCallback (void *UserData , CUpti_CallbackDomain,
87+ CUpti_CallbackId CBID, const void *CBData) {
4088 if (xptiTraceEnabled ()) {
4189 const auto *CBInfo = static_cast <const CUpti_CallbackData *>(CBData);
90+ cuda_tracing_context_t_ *Ctx =
91+ static_cast <cuda_tracing_context_t_ *>(UserData);
4292
4393 if (CBInfo->callbackSite == CUPTI_API_ENTER) {
4494 CallCorrelationID = xptiGetUniqueId ();
@@ -57,22 +107,95 @@ static void cuptiCallback(void *, CUpti_CallbackDomain, CUpti_CallbackId CBID,
57107 uint8_t CallStreamID = xptiRegisterStream (CUDA_CALL_STREAM_NAME);
58108 uint8_t DebugStreamID = xptiRegisterStream (CUDA_DEBUG_STREAM_NAME);
59109
60- xptiNotifySubscribers (CallStreamID, TraceType, GCallEvent , nullptr ,
110+ xptiNotifySubscribers (CallStreamID, TraceType, Ctx-> CallEvent , nullptr ,
61111 CallCorrelationID, FuncName);
62112
63113 xpti::function_with_args_t Payload{
64114 FuncID, FuncName, const_cast <void *>(CBInfo->functionParams ),
65115 CBInfo->functionReturnValue , CBInfo->context };
66- xptiNotifySubscribers (DebugStreamID, TraceTypeArgs, GDebugEvent, nullptr ,
67- DebugCorrelationID, &Payload);
116+ xptiNotifySubscribers (DebugStreamID, TraceTypeArgs, Ctx-> DebugEvent ,
117+ nullptr , DebugCorrelationID, &Payload);
68118 }
69119}
70120#endif
71121
122+ cuda_tracing_context_t_ *createCUDATracingContext () {
123+ #ifdef XPTI_ENABLE_INSTRUMENTATION
124+ if (!xptiTraceEnabled ())
125+ return nullptr ;
126+ return new cuda_tracing_context_t_;
127+ #else
128+ return nullptr ;
129+ #endif // XPTI_ENABLE_INSTRUMENTATION
130+ }
131+
132+ void freeCUDATracingContext (cuda_tracing_context_t_ *Ctx) {
133+ #ifdef XPTI_ENABLE_INSTRUMENTATION
134+ unloadCUDATracingLibrary (Ctx);
135+ delete Ctx;
136+ #else
137+ (void )Ctx;
138+ #endif // XPTI_ENABLE_INSTRUMENTATION
139+ }
140+
141+ bool cupti_table_t_::isInitialized () const {
142+ return Subscribe && Unsubscribe && EnableDomain && EnableCallback;
143+ }
144+
145+ bool loadCUDATracingLibrary (cuda_tracing_context_t_ *Ctx) {
146+ #if defined(XPTI_ENABLE_INSTRUMENTATION) && defined(CUPTI_LIB_PATH)
147+ if (!Ctx)
148+ return false ;
149+ if (Ctx->Library )
150+ return true ;
151+ auto Lib{ur_loader::LibLoader::loadAdapterLibrary (CUPTI_LIB_PATH)};
152+ if (!Lib)
153+ return false ;
154+ cupti_table_t_ Table;
155+ LOAD_CUPTI_SYM (Table, Lib, Subscribe)
156+ LOAD_CUPTI_SYM (Table, Lib, Unsubscribe)
157+ LOAD_CUPTI_SYM (Table, Lib, EnableDomain)
158+ LOAD_CUPTI_SYM (Table, Lib, EnableCallback)
159+ if (!Table.isInitialized ()) {
160+ return false ;
161+ }
162+ Ctx->Library = std::move (Lib);
163+ Ctx->Cupti = Table;
164+ return true ;
165+ #else
166+ (void )Ctx;
167+ return false ;
168+ #endif // XPTI_ENABLE_INSTRUMENTATION && CUPTI_LIB_PATH
169+ }
170+
171+ void unloadCUDATracingLibrary (cuda_tracing_context_t_ *Ctx) {
172+ #ifdef XPTI_ENABLE_INSTRUMENTATION
173+ if (!Ctx)
174+ return ;
175+ Ctx->Library .reset ();
176+ Ctx->Cupti = cupti_table_t_ ();
177+ #else
178+ (void )Ctx;
179+ #endif // XPTI_ENABLE_INSTRUMENTATION
180+ }
181+
72182void enableCUDATracing () {
73183#ifdef XPTI_ENABLE_INSTRUMENTATION
74184 if (!xptiTraceEnabled ())
75185 return ;
186+ static cuda_tracing_context_t_ *Ctx = nullptr ;
187+ if (!Ctx)
188+ Ctx = createCUDATracingContext ();
189+ enableCUDATracing (Ctx);
190+ #endif
191+ }
192+
193+ void enableCUDATracing (cuda_tracing_context_t_ *Ctx) {
194+ #ifdef XPTI_ENABLE_INSTRUMENTATION
195+ if (!Ctx || !xptiTraceEnabled ())
196+ return ;
197+ else if (!loadCUDATracingLibrary (Ctx))
198+ return ;
76199
77200 xptiRegisterStream (CUDA_CALL_STREAM_NAME);
78201 xptiInitialize (CUDA_CALL_STREAM_NAME, GMajVer, GMinVer, GVerStr);
@@ -81,31 +204,39 @@ void enableCUDATracing() {
81204
82205 uint64_t Dummy;
83206 xpti::payload_t CUDAPayload (" CUDA Plugin Layer" );
84- GCallEvent =
207+ Ctx-> CallEvent =
85208 xptiMakeEvent (" CUDA Plugin Layer" , &CUDAPayload,
86209 xpti::trace_algorithm_event, xpti_at::active, &Dummy);
87210
88211 xpti::payload_t CUDADebugPayload (" CUDA Plugin Debug Layer" );
89- GDebugEvent =
212+ Ctx-> DebugEvent =
90213 xptiMakeEvent (" CUDA Plugin Debug Layer" , &CUDADebugPayload,
91214 xpti::trace_algorithm_event, xpti_at::active, &Dummy);
92215
93- CUpti_SubscriberHandle Subscriber;
94- cuptiSubscribe (&Subscriber, cuptiCallback, nullptr );
95- cuptiEnableDomain (1 , Subscriber, CUPTI_CB_DOMAIN_DRIVER_API);
96- cuptiEnableCallback (0 , Subscriber, CUPTI_CB_DOMAIN_DRIVER_API,
97- CUPTI_DRIVER_TRACE_CBID_cuGetErrorString);
98- cuptiEnableCallback (0 , Subscriber, CUPTI_CB_DOMAIN_DRIVER_API,
99- CUPTI_DRIVER_TRACE_CBID_cuGetErrorName);
216+ Ctx->Cupti .Subscribe (&Ctx->Subscriber , cuptiCallback, Ctx);
217+ Ctx->Cupti .EnableDomain (1 , Ctx->Subscriber , CUPTI_CB_DOMAIN_DRIVER_API);
218+ Ctx->Cupti .EnableCallback (0 , Ctx->Subscriber , CUPTI_CB_DOMAIN_DRIVER_API,
219+ CUPTI_DRIVER_TRACE_CBID_cuGetErrorString);
220+ Ctx->Cupti .EnableCallback (0 , Ctx->Subscriber , CUPTI_CB_DOMAIN_DRIVER_API,
221+ CUPTI_DRIVER_TRACE_CBID_cuGetErrorName);
222+ #else
223+ (void )Ctx;
100224#endif
101225}
102226
103- void disableCUDATracing () {
227+ void disableCUDATracing (cuda_tracing_context_t_ *Ctx ) {
104228#ifdef XPTI_ENABLE_INSTRUMENTATION
105- if (!xptiTraceEnabled ())
229+ if (!Ctx || ! xptiTraceEnabled ())
106230 return ;
107231
232+ if (Ctx->Subscriber && Ctx->Cupti .isInitialized ()) {
233+ Ctx->Cupti .Unsubscribe (Ctx->Subscriber );
234+ Ctx->Subscriber = nullptr ;
235+ }
236+
108237 xptiFinalize (CUDA_CALL_STREAM_NAME);
109238 xptiFinalize (CUDA_DEBUG_STREAM_NAME);
239+ #else
240+ (void )Ctx;
110241#endif // XPTI_ENABLE_INSTRUMENTATION
111242}
0 commit comments