Skip to content

Commit 982aebc

Browse files
committed
Listen to runtime and driver APIs
If a correlation ID is associated with a runtime callback, use that; if not, use the driver API callback. Let the agent know the difference by signing the callback ID: driver == negative, runtime == positive. This is necessary because the callback IDs of the two APIs overlap.
1 parent 11ca3c2 commit 982aebc

File tree

1 file changed

+70
-12
lines changed

1 file changed

+70
-12
lines changed

cupti/cupti-prof.c

Lines changed: 70 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ static CUpti_SubscriberHandle subscriber = 0;
2828

2929
static size_t outstandingEvents = 0;
3030

31+
// Thread-local tracking: store correlation ID from runtime ENTER
32+
// so we can skip driver EXIT probe when it matches (driver calls happen under runtime calls)
33+
static __thread uint32_t runtimeEnterCorrelationId = 0;
34+
3135
static void init_debug(void) {
3236
static bool initialized = false;
3337
if (!initialized) {
@@ -87,7 +91,7 @@ int InitializeInjection(void) {
8791
}
8892

8993
// Enable all runtime API kernel launch callbacks
90-
CUpti_CallbackId launchCallbacks[] = {
94+
CUpti_CallbackId runtimeCallbacks[] = {
9195
CUPTI_RUNTIME_TRACE_CBID_cudaLaunch_v3020,
9296
CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000,
9397
CUPTI_RUNTIME_TRACE_CBID_cudaLaunch_ptsz_v7000,
@@ -100,15 +104,42 @@ int InitializeInjection(void) {
100104
CUPTI_RUNTIME_TRACE_CBID_cudaGraphLaunch_v10000,
101105
CUPTI_RUNTIME_TRACE_CBID_cudaGraphLaunch_ptsz_v10000,
102106
};
103-
for (size_t i = 0; i < sizeof(launchCallbacks) / sizeof(launchCallbacks[0]);
107+
for (size_t i = 0; i < sizeof(runtimeCallbacks) / sizeof(runtimeCallbacks[0]);
104108
i++) {
105109
result = cuptiEnableCallback(1, subscriber, CUPTI_CB_DOMAIN_RUNTIME_API,
106-
launchCallbacks[i]);
110+
runtimeCallbacks[i]);
107111
if (result != CUPTI_SUCCESS) {
108112
const char *errstr;
109113
cuptiGetResultString(result, &errstr);
110114
fprintf(stderr, "[CUPTI] Failed to enable runtime callback %d: %s\n",
111-
launchCallbacks[i], errstr);
115+
runtimeCallbacks[i], errstr);
116+
}
117+
}
118+
119+
// Enable all driver API kernel launch callbacks
120+
CUpti_CallbackId driverCallbacks[] = {
121+
CUPTI_DRIVER_TRACE_CBID_cuLaunch,
122+
CUPTI_DRIVER_TRACE_CBID_cuLaunchGrid,
123+
CUPTI_DRIVER_TRACE_CBID_cuLaunchGridAsync,
124+
CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel,
125+
CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel_ptsz,
126+
CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx,
127+
CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx_ptsz,
128+
CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel,
129+
CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel_ptsz,
130+
CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernelMultiDevice,
131+
CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch,
132+
CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch_ptsz,
133+
};
134+
for (size_t i = 0; i < sizeof(driverCallbacks) / sizeof(driverCallbacks[0]);
135+
i++) {
136+
result = cuptiEnableCallback(1, subscriber, CUPTI_CB_DOMAIN_DRIVER_API,
137+
driverCallbacks[i]);
138+
if (result != CUPTI_SUCCESS) {
139+
const char *errstr;
140+
cuptiGetResultString(result, &errstr);
141+
fprintf(stderr, "[CUPTI] Failed to enable driver callback %d: %s\n",
142+
driverCallbacks[i], errstr);
112143
}
113144
}
114145

@@ -159,28 +190,55 @@ static void print_backtrace(const char *prefix) {
159190
}
160191
}
161192

162-
// Callback handler for runtime API
193+
// Callback handler for driver and runtime API
163194
static void parcagpuCuptiCallback(void *userdata, CUpti_CallbackDomain domain,
164195
CUpti_CallbackId cbid,
165196
const CUpti_CallbackData *cbdata) {
166-
if (domain != CUPTI_CB_DOMAIN_RUNTIME_API) {
197+
uint32_t correlationId = cbdata->correlationId;
198+
199+
// Track runtime ENTER so we can skip driver EXIT when they match
200+
if (domain == CUPTI_CB_DOMAIN_RUNTIME_API &&
201+
cbdata->callbackSite == CUPTI_API_ENTER) {
202+
runtimeEnterCorrelationId = correlationId;
167203
return;
168204
}
169205

170206
// We hook on EXIT because that makes our probe overhead not add to GPU
171207
// launch latency and hopefully covers some of the overhead in the shadow of
172208
// GPU async work.
173-
if (cbdata->callbackSite == CUPTI_API_EXIT) {
174-
uint32_t correlationId = cbdata->correlationId;
175-
const char *name =
176-
cbdata->symbolName ? cbdata->symbolName : cbdata->functionName;
209+
if (cbdata->callbackSite != CUPTI_API_EXIT) {
210+
return;
211+
}
177212

213+
const char *name =
214+
cbdata->symbolName ? cbdata->symbolName : cbdata->functionName;
215+
int signedCbid;
216+
217+
if (domain == CUPTI_CB_DOMAIN_DRIVER_API) {
218+
// Skip if this driver call is under a runtime call (same correlation ID)
219+
if (correlationId == runtimeEnterCorrelationId) {
220+
DEBUG_PRINTF(
221+
"[CUPTI] Skipping driver EXIT correlationId=%u - runtime will handle\n",
222+
correlationId);
223+
return;
224+
}
225+
// Pure driver call (no runtime wrapper) - use negative cbid
226+
signedCbid = -(int)cbid;
227+
DEBUG_PRINTF(
228+
"[CUPTI] Driver API callback: cbid=%d, correlationId=%u, func=%s\n",
229+
cbid, correlationId, name);
230+
} else if (domain == CUPTI_CB_DOMAIN_RUNTIME_API) {
231+
signedCbid = (int)cbid;
232+
runtimeEnterCorrelationId = 0; // Clear after use
178233
DEBUG_PRINTF(
179234
"[CUPTI] Runtime API callback: cbid=%d, correlationId=%u, func=%s\n",
180235
cbid, correlationId, name);
181-
outstandingEvents++;
182-
DTRACE_PROBE3(parcagpu, cuda_correlation, correlationId, cbid, name);
236+
} else {
237+
return;
183238
}
239+
240+
outstandingEvents++;
241+
DTRACE_PROBE3(parcagpu, cuda_correlation, correlationId, signedCbid, name);
184242
// If we let too many events pile up it overwhelms the perf_event buffers,
185243
// just another reason to explore just passing the activity buffer through to
186244
// eBPF.

0 commit comments

Comments
 (0)