@@ -24,6 +24,7 @@ namespace ur_tracing_layer {
2424context_t *getContext () { return context_t::get_direct (); }
2525
2626constexpr auto CALL_STREAM_NAME = " ur.call" ;
27+ constexpr auto DEBUG_CALL_STREAM_NAME = " ur.call.debug" ;
2728constexpr auto STREAM_VER_MAJOR = UR_MAJOR_VERSION(UR_API_VERSION_CURRENT);
2829constexpr auto STREAM_VER_MINOR = UR_MINOR_VERSION(UR_API_VERSION_CURRENT);
2930
@@ -51,10 +52,13 @@ context_t::context_t() : logger(logger::create_logger("tracing", true, true)) {
5152 this ->xptiContextManager = xptiContextManagerGet ();
5253
5354 call_stream_id = xptiRegisterStream (CALL_STREAM_NAME);
55+ debug_call_stream_id = xptiRegisterStream (DEBUG_CALL_STREAM_NAME);
5456 std::ostringstream streamv;
5557 streamv << STREAM_VER_MAJOR << " ." << STREAM_VER_MINOR;
5658 xptiInitialize (CALL_STREAM_NAME, STREAM_VER_MAJOR, STREAM_VER_MINOR,
5759 streamv.str ().data ());
60+ xptiInitialize (DEBUG_CALL_STREAM_NAME, STREAM_VER_MAJOR, STREAM_VER_MINOR,
61+ streamv.str ().data ());
5862 // Create global event for all UR API calls.
5963 xpti_tracepoint_t *Event =
6064 xptiCreateTracepoint (" Unified Runtime call" , nullptr , 0 , 0 , (void *)this );
@@ -66,35 +70,36 @@ context_t::context_t() : logger(logger::create_logger("tracing", true, true)) {
6670void context_t::notify (uint16_t trace_type, uint32_t id, const char *name,
6771 void *args, ur_result_t *resultp, uint64_t instance) {
6872 xpti::function_with_args_t payload{id, name, args, resultp, nullptr };
69- // Use global event for all UR API calls
70- xptiNotifySubscribers (call_stream_id, trace_type, nullptr , activeEvent,
71- instance, &payload);
73+ if (xptiCheckTraceEnabled (debug_call_stream_id)) {
74+ xptiNotifySubscribers (debug_call_stream_id, trace_type, nullptr ,
75+ activeEvent, instance, &payload);
76+ } else {
77+ // Use global event for all UR API calls
78+ if (xptiCheckTraceEnabled (call_stream_id))
79+ xptiNotifySubscribers (call_stream_id, trace_type, nullptr , activeEvent,
80+ instance, &payload);
81+ }
7282}
7383
7484uint64_t context_t::notify_begin (uint32_t id, const char *name, void *args) {
75- // we use UINT64_MAX as a special value that means "tracing disabled",
76- // so that we don't have to repeat this check in notify_end.
77- if (!xptiCheckTraceEnabled (call_stream_id)) {
85+ if (xptiCheckTraceEnabled (debug_call_stream_id)) {
86+ // Create a new tracepoint with code location info for each UR API call.
87+ // This adds significant overhead to the tracing toolchain, so do this only
88+ // if there are debug stream subscribers.
89+ if (auto loc = codelocData.get_codeloc ()) {
90+ xpti_tracepoint_t *Event = xptiCreateTracepoint (
91+ loc->functionName , loc->sourceFile , loc->lineNumber ,
92+ loc->columnNumber , (void *)this );
93+ activeEvent = Event ? Event->event_ref () : nullptr ;
94+ }
95+ } else if (xptiCheckTraceEnabled (call_stream_id)) {
96+ // Otherwise use global event for all UR API calls.
97+ activeEvent = GURCallEvent;
98+ } else {
99+ // We use UINT64_MAX as a special value that means "tracing disabled",
100+ // so that we don't have to repeat this check in notify_end.
78101 return UINT64_MAX;
79102 }
80-
81- // Previous implementation created a new event for each UR API call. This
82- // adds significant overhead to the tracing toolchain. Replacing the
83- // previous code with a single global event for all UR API calls:
84- //
85- // PREVIOUS CODE:
86- // if (auto loc = codelocData.get_codeloc()) {
87- // xpti::payload_t payload =
88- // xpti::payload_t(loc->functionName, loc->sourceFile, loc->lineNumber,
89- // loc->columnNumber, nullptr);
90- // uint64_t InstanceNumber{};
91- // activeEvent =
92- // xptiMakeEvent("Unified Runtime call", &payload,
93- // xpti::trace_graph_event,
94- // xpti_at::active, &InstanceNumber);
95- // }
96-
97- activeEvent = GURCallEvent;
98103 uint64_t instance = xptiGetUniqueId ();
99104 notify ((uint16_t )xpti::trace_point_type_t ::function_with_args_begin, id, name,
100105 args, nullptr , instance);
0 commit comments