16
16
#include < cuda.h>
17
17
#ifdef XPTI_ENABLE_INSTRUMENTATION
18
18
#include < cupti.h>
19
+ #include < dlfcn.h>
19
20
#endif // XPTI_ENABLE_INSTRUMENTATION
20
21
22
+ #include " tracing.hpp"
21
23
#include < exception>
22
24
#include < iostream>
23
25
26
+ #ifdef XPTI_ENABLE_INSTRUMENTATION
27
+ using tracing_event_t = xpti_td *;
28
+ using subscriber_handle_t = CUpti_SubscriberHandle;
29
+
30
+ using cuptiSubscribe_fn = CUPTIAPI
31
+ CUptiResult (*)(CUpti_SubscriberHandle *subscriber, CUpti_CallbackFunc callback,
32
+ void *userdata);
33
+
34
+ using cuptiUnsubscribe_fn = CUPTIAPI
35
+ CUptiResult (*)(CUpti_SubscriberHandle subscriber);
36
+
37
+ using cuptiEnableDomain_fn = CUPTIAPI
38
+ CUptiResult (*)(uint32_t enable, CUpti_SubscriberHandle subscriber,
39
+ CUpti_CallbackDomain domain);
40
+
41
+ using cuptiEnableCallback_fn = CUPTIAPI
42
+ CUptiResult (*)(uint32_t enable, CUpti_SubscriberHandle subscriber,
43
+ CUpti_CallbackDomain domain, CUpti_CallbackId cbid);
44
+
45
+ #define LOAD_CUPTI_SYM (p, x ) \
46
+ p->x = (cupti##x##_fn)dlsym(p->Library, " cupti" #x);
47
+
48
+ #else
49
+ using tracing_event_t = void *;
50
+ using subscriber_handle_t = void *;
51
+ using cuptiSubscribe_fn = void *;
52
+ using cuptiUnsubscribe_fn = void *;
53
+ using cuptiEnableDomain_fn = void *;
54
+ using cuptiEnableCallback_fn = void *;
55
+ #endif // XPTI_ENABLE_INSTRUMENTATION
56
+
57
+ struct cuda_tracing_context_t_ {
58
+ tracing_event_t CallEvent = nullptr ;
59
+ tracing_event_t DebugEvent = nullptr ;
60
+ subscriber_handle_t Subscriber = nullptr ;
61
+ void *Library = nullptr ;
62
+ cuptiSubscribe_fn Subscribe = nullptr ;
63
+ cuptiUnsubscribe_fn Unsubscribe = nullptr ;
64
+ cuptiEnableDomain_fn EnableDomain = nullptr ;
65
+ cuptiEnableCallback_fn EnableCallback = nullptr ;
66
+ };
67
+
24
68
#ifdef XPTI_ENABLE_INSTRUMENTATION
25
69
constexpr auto CUDA_CALL_STREAM_NAME = " sycl.experimental.cuda.call" ;
26
70
constexpr auto CUDA_DEBUG_STREAM_NAME = " sycl.experimental.cuda.debug" ;
27
71
28
72
thread_local uint64_t CallCorrelationID = 0 ;
29
73
thread_local uint64_t DebugCorrelationID = 0 ;
30
74
31
- static xpti_td *GCallEvent = nullptr ;
32
- static xpti_td *GDebugEvent = nullptr ;
33
-
34
75
constexpr auto GVerStr = " 0.1" ;
35
76
constexpr int GMajVer = 0 ;
36
77
constexpr int GMinVer = 1 ;
37
78
38
- static void cuptiCallback (void *, CUpti_CallbackDomain, CUpti_CallbackId CBID ,
39
- const void *CBData) {
79
+ static void cuptiCallback (void *UserData , CUpti_CallbackDomain,
80
+ CUpti_CallbackId CBID, const void *CBData) {
40
81
if (xptiTraceEnabled ()) {
41
82
const auto *CBInfo = static_cast <const CUpti_CallbackData *>(CBData);
83
+ cuda_tracing_context_t_ *Ctx =
84
+ static_cast <cuda_tracing_context_t_ *>(UserData);
42
85
43
86
if (CBInfo->callbackSite == CUPTI_API_ENTER) {
44
87
CallCorrelationID = xptiGetUniqueId ();
@@ -57,22 +100,94 @@ static void cuptiCallback(void *, CUpti_CallbackDomain, CUpti_CallbackId CBID,
57
100
uint8_t CallStreamID = xptiRegisterStream (CUDA_CALL_STREAM_NAME);
58
101
uint8_t DebugStreamID = xptiRegisterStream (CUDA_DEBUG_STREAM_NAME);
59
102
60
- xptiNotifySubscribers (CallStreamID, TraceType, GCallEvent , nullptr ,
103
+ xptiNotifySubscribers (CallStreamID, TraceType, Ctx-> CallEvent , nullptr ,
61
104
CallCorrelationID, FuncName);
62
105
63
106
xpti::function_with_args_t Payload{
64
107
FuncID, FuncName, const_cast <void *>(CBInfo->functionParams ),
65
108
CBInfo->functionReturnValue , CBInfo->context };
66
- xptiNotifySubscribers (DebugStreamID, TraceTypeArgs, GDebugEvent, nullptr ,
67
- DebugCorrelationID, &Payload);
109
+ xptiNotifySubscribers (DebugStreamID, TraceTypeArgs, Ctx-> DebugEvent ,
110
+ nullptr , DebugCorrelationID, &Payload);
68
111
}
69
112
}
70
113
#endif
71
114
115
+ cuda_tracing_context_t_ *createCUDATracingContext () {
116
+ #ifdef XPTI_ENABLE_INSTRUMENTATION
117
+ if (!xptiTraceEnabled ())
118
+ return nullptr ;
119
+ return new cuda_tracing_context_t_;
120
+ #else
121
+ return nullptr ;
122
+ #endif // XPTI_ENABLE_INSTRUMENTATION
123
+ }
124
+
125
+ void freeCUDATracingContext (cuda_tracing_context_t_ *Ctx) {
126
+ #ifdef XPTI_ENABLE_INSTRUMENTATION
127
+ unloadCUDATracingLibrary (Ctx);
128
+ delete Ctx;
129
+ #else
130
+ (void )Ctx;
131
+ #endif // XPTI_ENABLE_INSTRUMENTATION
132
+ }
133
+
134
+ bool loadCUDATracingLibrary (cuda_tracing_context_t_ *Ctx) {
135
+ #if defined(XPTI_ENABLE_INSTRUMENTATION) && defined(CUPTI_LIB_PATH)
136
+ if (!Ctx)
137
+ return false ;
138
+ if (Ctx->Library )
139
+ return true ;
140
+ Ctx->Library = dlopen (CUPTI_LIB_PATH, RTLD_NOW);
141
+ if (!Ctx->Library )
142
+ return false ;
143
+ LOAD_CUPTI_SYM (Ctx, Subscribe)
144
+ LOAD_CUPTI_SYM (Ctx, Unsubscribe)
145
+ LOAD_CUPTI_SYM (Ctx, EnableDomain)
146
+ LOAD_CUPTI_SYM (Ctx, EnableCallback)
147
+ if (!Ctx->Subscribe || !Ctx->Unsubscribe || !Ctx->EnableDomain ||
148
+ !Ctx->EnableCallback ) {
149
+ unloadCUDATracingLibrary (Ctx);
150
+ return false ;
151
+ }
152
+ return true ;
153
+ #else
154
+ (void )Ctx;
155
+ return false ;
156
+ #endif // XPTI_ENABLE_INSTRUMENTATION && CUPTI_LIB_PATH
157
+ }
158
+
159
+ void unloadCUDATracingLibrary (cuda_tracing_context_t_ *Ctx) {
160
+ #ifdef XPTI_ENABLE_INSTRUMENTATION
161
+ if (!Ctx || !Ctx->Library )
162
+ return ;
163
+ Ctx->Subscribe = nullptr ;
164
+ Ctx->Unsubscribe = nullptr ;
165
+ Ctx->EnableDomain = nullptr ;
166
+ Ctx->EnableCallback = nullptr ;
167
+ dlclose (Ctx->Library );
168
+ Ctx->Library = nullptr ;
169
+ #else
170
+ (void )Ctx;
171
+ #endif // XPTI_ENABLE_INSTRUMENTATION
172
+ }
173
+
72
174
void enableCUDATracing () {
73
175
#ifdef XPTI_ENABLE_INSTRUMENTATION
74
176
if (!xptiTraceEnabled ())
75
177
return ;
178
+ static cuda_tracing_context_t_ *Ctx = nullptr ;
179
+ if (!Ctx)
180
+ Ctx = createCUDATracingContext ();
181
+ enableCUDATracing (Ctx);
182
+ #endif
183
+ }
184
+
185
+ void enableCUDATracing (cuda_tracing_context_t_ *Ctx) {
186
+ #ifdef XPTI_ENABLE_INSTRUMENTATION
187
+ if (!Ctx || !xptiTraceEnabled ())
188
+ return ;
189
+ else if (!loadCUDATracingLibrary (Ctx))
190
+ return ;
76
191
77
192
xptiRegisterStream (CUDA_CALL_STREAM_NAME);
78
193
xptiInitialize (CUDA_CALL_STREAM_NAME, GMajVer, GMinVer, GVerStr);
@@ -81,31 +196,39 @@ void enableCUDATracing() {
81
196
82
197
uint64_t Dummy;
83
198
xpti::payload_t CUDAPayload (" CUDA Plugin Layer" );
84
- GCallEvent =
199
+ Ctx-> CallEvent =
85
200
xptiMakeEvent (" CUDA Plugin Layer" , &CUDAPayload,
86
201
xpti::trace_algorithm_event, xpti_at::active, &Dummy);
87
202
88
203
xpti::payload_t CUDADebugPayload (" CUDA Plugin Debug Layer" );
89
- GDebugEvent =
204
+ Ctx-> DebugEvent =
90
205
xptiMakeEvent (" CUDA Plugin Debug Layer" , &CUDADebugPayload,
91
206
xpti::trace_algorithm_event, xpti_at::active, &Dummy);
92
207
93
- CUpti_SubscriberHandle Subscriber;
94
- cuptiSubscribe (&Subscriber, cuptiCallback, nullptr );
95
- cuptiEnableDomain (1 , Subscriber, CUPTI_CB_DOMAIN_DRIVER_API);
96
- cuptiEnableCallback (0 , Subscriber, CUPTI_CB_DOMAIN_DRIVER_API,
208
+ Ctx->Subscribe (&Ctx->Subscriber , cuptiCallback, Ctx);
209
+ Ctx->EnableDomain (1 , Ctx->Subscriber , CUPTI_CB_DOMAIN_DRIVER_API);
210
+ Ctx->EnableCallback (0 , Ctx->Subscriber , CUPTI_CB_DOMAIN_DRIVER_API,
97
211
CUPTI_DRIVER_TRACE_CBID_cuGetErrorString);
98
- cuptiEnableCallback (0 , Subscriber, CUPTI_CB_DOMAIN_DRIVER_API,
212
+ Ctx-> EnableCallback (0 , Ctx-> Subscriber , CUPTI_CB_DOMAIN_DRIVER_API,
99
213
CUPTI_DRIVER_TRACE_CBID_cuGetErrorName);
214
+ #else
215
+ (void )Ctx;
100
216
#endif
101
217
}
102
218
103
- void disableCUDATracing () {
219
+ void disableCUDATracing (cuda_tracing_context_t_ *Ctx ) {
104
220
#ifdef XPTI_ENABLE_INSTRUMENTATION
105
- if (!xptiTraceEnabled ())
221
+ if (!Ctx || ! xptiTraceEnabled ())
106
222
return ;
107
223
224
+ if (Ctx->Subscriber ) {
225
+ Ctx->Unsubscribe (Ctx->Subscriber );
226
+ Ctx->Subscriber = nullptr ;
227
+ }
228
+
108
229
xptiFinalize (CUDA_CALL_STREAM_NAME);
109
230
xptiFinalize (CUDA_DEBUG_STREAM_NAME);
231
+ #else
232
+ (void )Ctx;
110
233
#endif // XPTI_ENABLE_INSTRUMENTATION
111
234
}
0 commit comments