Skip to content

Commit 311afb9

Browse files
Pierre-Andre Saulais and kbenzie
authored and committed
[CUDA] Dynamically load the CUPTI library when tracing
1 parent 6fb1e54 commit 311afb9

File tree

4 files changed

+175
-22
lines changed

4 files changed

+175
-22
lines changed

source/adapters/cuda/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ else()
7676
message(WARNING "CUDA adapter USM pools are disabled, set UMF_ENABLE_POOL_TRACKING to enable them")
7777
endif()
7878

79+
if (CUDA_cupti_LIBRARY)
80+
target_compile_definitions("ur_adapter_cuda" PRIVATE CUPTI_LIB_PATH="${CUDA_cupti_LIBRARY}")
81+
endif()
82+
7983
target_link_libraries(${TARGET_NAME} PRIVATE
8084
${PROJECT_NAME}::headers
8185
${PROJECT_NAME}::common

source/adapters/cuda/adapter.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,12 @@
1111
#include <ur_api.h>
1212

1313
#include "common.hpp"
14-
15-
void enableCUDATracing();
16-
void disableCUDATracing();
14+
#include "tracing.hpp"
1715

1816
// State of the single CUDA adapter instance (`adapter`, defined just below).
struct ur_adapter_handle_t_ {
1917
// Incremented in urAdapterGet and decremented in urAdapterRelease; tracing
// is set up on the 0 -> 1 transition and torn down when it returns to 0.
std::atomic<uint32_t> RefCount = 0;
2018
// Locked by urAdapterGet and urAdapterRelease around the RefCount update
// and the tracing setup/teardown.
std::mutex Mutex;
19+
// Tracing context obtained from createCUDATracingContext() on first get;
// freed and reset to nullptr on last release. May be nullptr, since
// createCUDATracingContext() can return nullptr.
struct cuda_tracing_context_t_ *TracingCtx = nullptr;
2120
};
2221

2322
ur_adapter_handle_t_ adapter{};
@@ -28,7 +27,8 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
2827
if (NumEntries > 0 && phAdapters) {
2928
std::lock_guard<std::mutex> Lock{adapter.Mutex};
3029
if (adapter.RefCount++ == 0) {
31-
enableCUDATracing();
30+
adapter.TracingCtx = createCUDATracingContext();
31+
enableCUDATracing(adapter.TracingCtx);
3232
}
3333

3434
*phAdapters = &adapter;
@@ -50,7 +50,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
5050
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
5151
std::lock_guard<std::mutex> Lock{adapter.Mutex};
5252
if (--adapter.RefCount == 0) {
53-
disableCUDATracing();
53+
disableCUDATracing(adapter.TracingCtx);
54+
freeCUDATracingContext(adapter.TracingCtx);
55+
adapter.TracingCtx = nullptr;
5456
}
5557
return UR_RESULT_SUCCESS;
5658
}

source/adapters/cuda/tracing.cpp

Lines changed: 140 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,72 @@
1616
#include <cuda.h>
1717
#ifdef XPTI_ENABLE_INSTRUMENTATION
1818
#include <cupti.h>
19+
#include <dlfcn.h>
1920
#endif // XPTI_ENABLE_INSTRUMENTATION
2021

22+
#include "tracing.hpp"
2123
#include <exception>
2224
#include <iostream>
2325

26+
#ifdef XPTI_ENABLE_INSTRUMENTATION
27+
using tracing_event_t = xpti_td *;
28+
using subscriber_handle_t = CUpti_SubscriberHandle;
29+
30+
using cuptiSubscribe_fn = CUPTIAPI
31+
CUptiResult (*)(CUpti_SubscriberHandle *subscriber, CUpti_CallbackFunc callback,
32+
void *userdata);
33+
34+
using cuptiUnsubscribe_fn = CUPTIAPI
35+
CUptiResult (*)(CUpti_SubscriberHandle subscriber);
36+
37+
using cuptiEnableDomain_fn = CUPTIAPI
38+
CUptiResult (*)(uint32_t enable, CUpti_SubscriberHandle subscriber,
39+
CUpti_CallbackDomain domain);
40+
41+
using cuptiEnableCallback_fn = CUPTIAPI
42+
CUptiResult (*)(uint32_t enable, CUpti_SubscriberHandle subscriber,
43+
CUpti_CallbackDomain domain, CUpti_CallbackId cbid);
44+
45+
#define LOAD_CUPTI_SYM(p, x) \
46+
p->x = (cupti##x##_fn)dlsym(p->Library, "cupti" #x);
47+
48+
#else
49+
using tracing_event_t = void *;
50+
using subscriber_handle_t = void *;
51+
using cuptiSubscribe_fn = void *;
52+
using cuptiUnsubscribe_fn = void *;
53+
using cuptiEnableDomain_fn = void *;
54+
using cuptiEnableCallback_fn = void *;
55+
#endif // XPTI_ENABLE_INSTRUMENTATION
56+
57+
// Tracing state handed to CUPTI as the callback's user data: the XPTI events
// reported by cuptiCallback, the CUPTI subscriber handle, and the dynamically
// loaded CUPTI library together with the entry points resolved from it.
// When XPTI_ENABLE_INSTRUMENTATION is not defined every member type is an
// alias of `void *`, so the struct is still well-formed.
struct cuda_tracing_context_t_ {
58+
// Event created with xptiMakeEvent in enableCUDATracing(Ctx); passed to
// xptiNotifySubscribers for the CUDA call stream.
tracing_event_t CallEvent = nullptr;
59+
// Event created with xptiMakeEvent in enableCUDATracing(Ctx); passed to
// xptiNotifySubscribers for the CUDA debug stream.
tracing_event_t DebugEvent = nullptr;
60+
// Filled in by the Subscribe call in enableCUDATracing(Ctx); reset to
// nullptr by disableCUDATracing after Unsubscribe.
subscriber_handle_t Subscriber = nullptr;
61+
// Handle returned by dlopen(CUPTI_LIB_PATH, ...); nullptr while the library
// is not loaded.
void *Library = nullptr;
62+
// CUPTI entry points resolved with dlsym by loadCUDATracingLibrary and
// cleared again by unloadCUDATracingLibrary.
cuptiSubscribe_fn Subscribe = nullptr;
63+
cuptiUnsubscribe_fn Unsubscribe = nullptr;
64+
cuptiEnableDomain_fn EnableDomain = nullptr;
65+
cuptiEnableCallback_fn EnableCallback = nullptr;
66+
};
67+
2468
#ifdef XPTI_ENABLE_INSTRUMENTATION
2569
constexpr auto CUDA_CALL_STREAM_NAME = "sycl.experimental.cuda.call";
2670
constexpr auto CUDA_DEBUG_STREAM_NAME = "sycl.experimental.cuda.debug";
2771

2872
thread_local uint64_t CallCorrelationID = 0;
2973
thread_local uint64_t DebugCorrelationID = 0;
3074

31-
static xpti_td *GCallEvent = nullptr;
32-
static xpti_td *GDebugEvent = nullptr;
33-
3475
constexpr auto GVerStr = "0.1";
3576
constexpr int GMajVer = 0;
3677
constexpr int GMinVer = 1;
3778

38-
static void cuptiCallback(void *, CUpti_CallbackDomain, CUpti_CallbackId CBID,
39-
const void *CBData) {
79+
static void cuptiCallback(void *UserData, CUpti_CallbackDomain,
80+
CUpti_CallbackId CBID, const void *CBData) {
4081
if (xptiTraceEnabled()) {
4182
const auto *CBInfo = static_cast<const CUpti_CallbackData *>(CBData);
83+
cuda_tracing_context_t_ *Ctx =
84+
static_cast<cuda_tracing_context_t_ *>(UserData);
4285

4386
if (CBInfo->callbackSite == CUPTI_API_ENTER) {
4487
CallCorrelationID = xptiGetUniqueId();
@@ -57,22 +100,94 @@ static void cuptiCallback(void *, CUpti_CallbackDomain, CUpti_CallbackId CBID,
57100
uint8_t CallStreamID = xptiRegisterStream(CUDA_CALL_STREAM_NAME);
58101
uint8_t DebugStreamID = xptiRegisterStream(CUDA_DEBUG_STREAM_NAME);
59102

60-
xptiNotifySubscribers(CallStreamID, TraceType, GCallEvent, nullptr,
103+
xptiNotifySubscribers(CallStreamID, TraceType, Ctx->CallEvent, nullptr,
61104
CallCorrelationID, FuncName);
62105

63106
xpti::function_with_args_t Payload{
64107
FuncID, FuncName, const_cast<void *>(CBInfo->functionParams),
65108
CBInfo->functionReturnValue, CBInfo->context};
66-
xptiNotifySubscribers(DebugStreamID, TraceTypeArgs, GDebugEvent, nullptr,
67-
DebugCorrelationID, &Payload);
109+
xptiNotifySubscribers(DebugStreamID, TraceTypeArgs, Ctx->DebugEvent,
110+
nullptr, DebugCorrelationID, &Payload);
68111
}
69112
}
70113
#endif
71114

115+
// Allocates a fresh, default-initialized tracing context.
//
// Returns nullptr when the build has no XPTI instrumentation or when
// xptiTraceEnabled() reports that tracing is off; callers must therefore
// accept a null context. A non-null result is heap-allocated with `new` and
// is released by freeCUDATracingContext().
cuda_tracing_context_t_ *createCUDATracingContext() {
116+
#ifdef XPTI_ENABLE_INSTRUMENTATION
117+
if (!xptiTraceEnabled())
118+
return nullptr;
119+
return new cuda_tracing_context_t_;
120+
#else
121+
return nullptr;
122+
#endif // XPTI_ENABLE_INSTRUMENTATION
123+
}
124+
125+
// Releases a context returned by createCUDATracingContext(): closes the CUPTI
// library if it is still loaded, then deletes the context.
//
// A null Ctx is tolerated (unloadCUDATracingLibrary returns early and
// `delete` on nullptr does nothing). This function does not unsubscribe from
// CUPTI; urAdapterRelease calls disableCUDATracing(Ctx) before it.
// Without XPTI instrumentation this is a no-op.
void freeCUDATracingContext(cuda_tracing_context_t_ *Ctx) {
126+
#ifdef XPTI_ENABLE_INSTRUMENTATION
127+
unloadCUDATracingLibrary(Ctx);
128+
delete Ctx;
129+
#else
130+
(void)Ctx;
131+
#endif // XPTI_ENABLE_INSTRUMENTATION
132+
}
133+
134+
// Loads the CUPTI shared library into Ctx and resolves the entry points the
// tracer needs (Subscribe, Unsubscribe, EnableDomain, EnableCallback).
//
// Returns true when the library is loaded and all four symbols were found
// (including the case where it was already loaded), false otherwise. Always
// returns false when the build lacks XPTI_ENABLE_INSTRUMENTATION or
// CUPTI_LIB_PATH; the CMake change in this commit defines CUPTI_LIB_PATH only
// when CUDA_cupti_LIBRARY is set.
bool loadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx) {
135+
#if defined(XPTI_ENABLE_INSTRUMENTATION) && defined(CUPTI_LIB_PATH)
136+
if (!Ctx)
137+
return false;
138+
// Already loaded by an earlier call: nothing to do.
if (Ctx->Library)
139+
return true;
140+
Ctx->Library = dlopen(CUPTI_LIB_PATH, RTLD_NOW);
141+
if (!Ctx->Library)
142+
return false;
143+
// Each LOAD_CUPTI_SYM expands to a dlsym lookup of "cupti<Name>" stored in
// the matching Ctx member.
LOAD_CUPTI_SYM(Ctx, Subscribe)
144+
LOAD_CUPTI_SYM(Ctx, Unsubscribe)
145+
LOAD_CUPTI_SYM(Ctx, EnableDomain)
146+
LOAD_CUPTI_SYM(Ctx, EnableCallback)
147+
// All-or-nothing: if any symbol is missing, close the library again so the
// context never holds a partially resolved set of entry points.
if (!Ctx->Subscribe || !Ctx->Unsubscribe || !Ctx->EnableDomain ||
148+
!Ctx->EnableCallback) {
149+
unloadCUDATracingLibrary(Ctx);
150+
return false;
151+
}
152+
return true;
153+
#else
154+
(void)Ctx;
155+
return false;
156+
#endif // XPTI_ENABLE_INSTRUMENTATION && CUPTI_LIB_PATH
157+
}
158+
159+
// Closes the CUPTI library held by Ctx and clears the resolved entry points.
//
// Does nothing when Ctx is null or no library is loaded, so it is safe to
// call repeatedly. Without XPTI instrumentation this is a no-op.
void unloadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx) {
160+
#ifdef XPTI_ENABLE_INSTRUMENTATION
161+
if (!Ctx || !Ctx->Library)
162+
return;
163+
// Clear the function pointers before closing the library they point into.
Ctx->Subscribe = nullptr;
164+
Ctx->Unsubscribe = nullptr;
165+
Ctx->EnableDomain = nullptr;
166+
Ctx->EnableCallback = nullptr;
167+
dlclose(Ctx->Library);
168+
Ctx->Library = nullptr;
169+
#else
170+
(void)Ctx;
171+
#endif // XPTI_ENABLE_INSTRUMENTATION
172+
}
173+
72174
// Context-free overload kept for callers that have not yet moved to
// enableCUDATracing(cuda_tracing_context_t_ *); tracing.hpp marks it as
// deprecated. It lazily creates one function-local static context and
// forwards to the context-taking overload.
// NOTE(review): nothing in this function frees the static context or
// disables tracing again — confirm that is acceptable for remaining callers.
void enableCUDATracing() {
73175
#ifdef XPTI_ENABLE_INSTRUMENTATION
74176
if (!xptiTraceEnabled())
75177
return;
178+
static cuda_tracing_context_t_ *Ctx = nullptr;
179+
// Creation is retried on each call until createCUDATracingContext()
// returns a non-null context.
if (!Ctx)
180+
Ctx = createCUDATracingContext();
181+
enableCUDATracing(Ctx);
182+
#endif
183+
}
184+
185+
void enableCUDATracing(cuda_tracing_context_t_ *Ctx) {
186+
#ifdef XPTI_ENABLE_INSTRUMENTATION
187+
if (!Ctx || !xptiTraceEnabled())
188+
return;
189+
else if (!loadCUDATracingLibrary(Ctx))
190+
return;
76191

77192
xptiRegisterStream(CUDA_CALL_STREAM_NAME);
78193
xptiInitialize(CUDA_CALL_STREAM_NAME, GMajVer, GMinVer, GVerStr);
@@ -81,31 +196,39 @@ void enableCUDATracing() {
81196

82197
uint64_t Dummy;
83198
xpti::payload_t CUDAPayload("CUDA Plugin Layer");
84-
GCallEvent =
199+
Ctx->CallEvent =
85200
xptiMakeEvent("CUDA Plugin Layer", &CUDAPayload,
86201
xpti::trace_algorithm_event, xpti_at::active, &Dummy);
87202

88203
xpti::payload_t CUDADebugPayload("CUDA Plugin Debug Layer");
89-
GDebugEvent =
204+
Ctx->DebugEvent =
90205
xptiMakeEvent("CUDA Plugin Debug Layer", &CUDADebugPayload,
91206
xpti::trace_algorithm_event, xpti_at::active, &Dummy);
92207

93-
CUpti_SubscriberHandle Subscriber;
94-
cuptiSubscribe(&Subscriber, cuptiCallback, nullptr);
95-
cuptiEnableDomain(1, Subscriber, CUPTI_CB_DOMAIN_DRIVER_API);
96-
cuptiEnableCallback(0, Subscriber, CUPTI_CB_DOMAIN_DRIVER_API,
208+
Ctx->Subscribe(&Ctx->Subscriber, cuptiCallback, Ctx);
209+
Ctx->EnableDomain(1, Ctx->Subscriber, CUPTI_CB_DOMAIN_DRIVER_API);
210+
Ctx->EnableCallback(0, Ctx->Subscriber, CUPTI_CB_DOMAIN_DRIVER_API,
97211
CUPTI_DRIVER_TRACE_CBID_cuGetErrorString);
98-
cuptiEnableCallback(0, Subscriber, CUPTI_CB_DOMAIN_DRIVER_API,
212+
Ctx->EnableCallback(0, Ctx->Subscriber, CUPTI_CB_DOMAIN_DRIVER_API,
99213
CUPTI_DRIVER_TRACE_CBID_cuGetErrorName);
214+
#else
215+
(void)Ctx;
100216
#endif
101217
}
102218

103-
void disableCUDATracing() {
219+
void disableCUDATracing(cuda_tracing_context_t_ *Ctx) {
104220
#ifdef XPTI_ENABLE_INSTRUMENTATION
105-
if (!xptiTraceEnabled())
221+
if (!Ctx || !xptiTraceEnabled())
106222
return;
107223

224+
if (Ctx->Subscriber) {
225+
Ctx->Unsubscribe(Ctx->Subscriber);
226+
Ctx->Subscriber = nullptr;
227+
}
228+
108229
xptiFinalize(CUDA_CALL_STREAM_NAME);
109230
xptiFinalize(CUDA_DEBUG_STREAM_NAME);
231+
#else
232+
(void)Ctx;
110233
#endif // XPTI_ENABLE_INSTRUMENTATION
111234
}

source/adapters/cuda/tracing.hpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//===--------- tracing.hpp - CUDA Host API Tracing -------------------------==//
2+
//
3+
// Copyright (C) 2023 Intel Corporation
4+
//
5+
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
6+
// Exceptions. See LICENSE.TXT
7+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
struct cuda_tracing_context_t_;
12+
13+
cuda_tracing_context_t_ *createCUDATracingContext();
14+
void freeCUDATracingContext(cuda_tracing_context_t_ *Ctx);
15+
16+
bool loadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx);
17+
void unloadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx);
18+
19+
void enableCUDATracing(cuda_tracing_context_t_ *Ctx);
20+
void disableCUDATracing(cuda_tracing_context_t_ *Ctx);
21+
22+
// Deprecated. Will be removed once pi_cuda has been updated to use the variant
23+
// that takes a context pointer.
24+
void enableCUDATracing();

0 commit comments

Comments
 (0)