@@ -28,6 +28,10 @@ static CUpti_SubscriberHandle subscriber = 0;
2828
2929static size_t outstandingEvents = 0 ;
3030
31+ // Thread-local tracking: store correlation ID from runtime ENTER
32+ // so we can skip driver EXIT probe when it matches (driver calls happen under runtime calls)
33+ static __thread uint32_t runtimeEnterCorrelationId = 0 ;
34+
3135static void init_debug (void ) {
3236 static bool initialized = false;
3337 if (!initialized ) {
@@ -87,7 +91,7 @@ int InitializeInjection(void) {
8791 }
8892
8993 // Enable all runtime API kernel launch callbacks
90- CUpti_CallbackId launchCallbacks [] = {
94+ CUpti_CallbackId runtimeCallbacks [] = {
9195 CUPTI_RUNTIME_TRACE_CBID_cudaLaunch_v3020 ,
9296 CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000 ,
9397 CUPTI_RUNTIME_TRACE_CBID_cudaLaunch_ptsz_v7000 ,
@@ -100,15 +104,42 @@ int InitializeInjection(void) {
100104 CUPTI_RUNTIME_TRACE_CBID_cudaGraphLaunch_v10000 ,
101105 CUPTI_RUNTIME_TRACE_CBID_cudaGraphLaunch_ptsz_v10000 ,
102106 };
103- for (size_t i = 0 ; i < sizeof (launchCallbacks ) / sizeof (launchCallbacks [0 ]);
107+ for (size_t i = 0 ; i < sizeof (runtimeCallbacks ) / sizeof (runtimeCallbacks [0 ]);
104108 i ++ ) {
105109 result = cuptiEnableCallback (1 , subscriber , CUPTI_CB_DOMAIN_RUNTIME_API ,
106- launchCallbacks [i ]);
110+ runtimeCallbacks [i ]);
107111 if (result != CUPTI_SUCCESS ) {
108112 const char * errstr ;
109113 cuptiGetResultString (result , & errstr );
110114 fprintf (stderr , "[CUPTI] Failed to enable runtime callback %d: %s\n" ,
111- launchCallbacks [i ], errstr );
115+ runtimeCallbacks [i ], errstr );
116+ }
117+ }
118+
119+ // Enable all driver API kernel launch callbacks
120+ CUpti_CallbackId driverCallbacks [] = {
121+ CUPTI_DRIVER_TRACE_CBID_cuLaunch ,
122+ CUPTI_DRIVER_TRACE_CBID_cuLaunchGrid ,
123+ CUPTI_DRIVER_TRACE_CBID_cuLaunchGridAsync ,
124+ CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel ,
125+ CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel_ptsz ,
126+ CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx ,
127+ CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx_ptsz ,
128+ CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel ,
129+ CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel_ptsz ,
130+ CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernelMultiDevice ,
131+ CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch ,
132+ CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch_ptsz ,
133+ };
134+ for (size_t i = 0 ; i < sizeof (driverCallbacks ) / sizeof (driverCallbacks [0 ]);
135+ i ++ ) {
136+ result = cuptiEnableCallback (1 , subscriber , CUPTI_CB_DOMAIN_DRIVER_API ,
137+ driverCallbacks [i ]);
138+ if (result != CUPTI_SUCCESS ) {
139+ const char * errstr ;
140+ cuptiGetResultString (result , & errstr );
141+ fprintf (stderr , "[CUPTI] Failed to enable driver callback %d: %s\n" ,
142+ driverCallbacks [i ], errstr );
112143 }
113144 }
114145
@@ -159,28 +190,55 @@ static void print_backtrace(const char *prefix) {
159190 }
160191}
161192
162- // Callback handler for runtime API
193+ // Callback handler for driver and runtime API
163194static void parcagpuCuptiCallback (void * userdata , CUpti_CallbackDomain domain ,
164195 CUpti_CallbackId cbid ,
165196 const CUpti_CallbackData * cbdata ) {
166- if (domain != CUPTI_CB_DOMAIN_RUNTIME_API ) {
197+ uint32_t correlationId = cbdata -> correlationId ;
198+
199+ // Track runtime ENTER so we can skip driver EXIT when they match
200+ if (domain == CUPTI_CB_DOMAIN_RUNTIME_API &&
201+ cbdata -> callbackSite == CUPTI_API_ENTER ) {
202+ runtimeEnterCorrelationId = correlationId ;
167203 return ;
168204 }
169205
170206 // We hook on EXIT because that makes our probe overhead not add to GPU
171207 // launch latency and hopefully covers some of the overhead in the shadow of
172208 // GPU async work.
173- if (cbdata -> callbackSite == CUPTI_API_EXIT ) {
174- uint32_t correlationId = cbdata -> correlationId ;
175- const char * name =
176- cbdata -> symbolName ? cbdata -> symbolName : cbdata -> functionName ;
209+ if (cbdata -> callbackSite != CUPTI_API_EXIT ) {
210+ return ;
211+ }
177212
213+ const char * name =
214+ cbdata -> symbolName ? cbdata -> symbolName : cbdata -> functionName ;
215+ int signedCbid ;
216+
217+ if (domain == CUPTI_CB_DOMAIN_DRIVER_API ) {
218+ // Skip if this driver call is under a runtime call (same correlation ID)
219+ if (correlationId == runtimeEnterCorrelationId ) {
220+ DEBUG_PRINTF (
221+ "[CUPTI] Skipping driver EXIT correlationId=%u - runtime will handle\n" ,
222+ correlationId );
223+ return ;
224+ }
225+ // Pure driver call (no runtime wrapper) - use negative cbid
226+ signedCbid = - (int )cbid ;
227+ DEBUG_PRINTF (
228+ "[CUPTI] Driver API callback: cbid=%d, correlationId=%u, func=%s\n" ,
229+ cbid , correlationId , name );
230+ } else if (domain == CUPTI_CB_DOMAIN_RUNTIME_API ) {
231+ signedCbid = (int )cbid ;
232+ runtimeEnterCorrelationId = 0 ; // Clear after use
178233 DEBUG_PRINTF (
179234 "[CUPTI] Runtime API callback: cbid=%d, correlationId=%u, func=%s\n" ,
180235 cbid , correlationId , name );
181- outstandingEvents ++ ;
182- DTRACE_PROBE3 ( parcagpu , cuda_correlation , correlationId , cbid , name ) ;
236+ } else {
237+ return ;
183238 }
239+
240+ outstandingEvents ++ ;
241+ DTRACE_PROBE3 (parcagpu , cuda_correlation , correlationId , signedCbid , name );
184242 // If we let too many events pile up it overwhelms the perf_event buffers,
185243 // just another reason to explore just passing the activity buffer through to
186244 // eBPF.
0 commit comments