Skip to content

Commit a3d6aa4

Browse files
authored
Use thread adoption to handle log messages. (#2754)
1 parent fe7e7b7 commit a3d6aa4

File tree

2 files changed

+34
-78
lines changed

2 files changed

+34
-78
lines changed

lib/cublas/CUBLAS.jl

Lines changed: 29 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -210,56 +210,45 @@ end
210210

211211
## logging
212212

213-
const MAX_LOG_BUFLEN = UInt(1024*1024)
214-
const log_buffer = Vector{UInt8}(undef, MAX_LOG_BUFLEN)
215-
const log_cursor = Threads.Atomic{UInt}(0)
216-
const log_cond = Ref{Base.AsyncCondition}() # root
213+
# CUBLAS calls the log callback multiple times for each message, so we need to buffer them
214+
const log_buffer = IOBuffer()
217215

218216
function log_message(ptr)
219-
# NOTE: this function may be called from unmanaged threads (by cublasXt),
220-
# so we can't even allocate, let alone perform I/O.
221-
len = @ccall strlen(ptr::Cstring)::Csize_t
222-
old_cursor = log_cursor[]
223-
new_cursor = old_cursor + len+1
224-
if new_cursor >= MAX_LOG_BUFLEN
225-
# overrun
226-
return
227-
end
217+
global log_buffer
218+
str = unsafe_string(ptr)
228219

229-
@ccall memmove((pointer(log_buffer)+old_cursor)::Ptr{Nothing},
230-
pointer(ptr)::Ptr{Nothing}, (len+1)::Csize_t)::Nothing
231-
log_cursor[] = new_cursor # the consumer handles CAS'ing this value
220+
# flush if we've started a new log message
221+
if startswith(str, r"[A-Z]!")
222+
flush_log_messages()
223+
end
232224

233-
# avoid code that depends on the runtime (even the unsafe_convert from ccall does?!)
234-
assume(isassigned(log_cond))
235-
@ccall uv_async_send(log_cond[].handle::Ptr{Nothing})::Cint
225+
# append the lines to the buffer
226+
println(log_buffer, str)
236227

237228
return
238229
end
239230

240-
function _log_message(blob)
231+
function flush_log_messages()
232+
global log_buffer
233+
message = String(take!(log_buffer))
234+
isempty(message) && return
235+
241236
# the message format isn't documented, but it looks like a message starts with a capital
242237
# and the severity (e.g. `I!`), and subsequent lines start with a lowercase mark (`i!`)
243-
#
244-
# lines are separated by a \0 if they came in separately, but there may also be multiple
245-
# actual lines separated by \n in each message.
246-
for message in split(blob, r"[\0\n]+(?=[A-Z]!)")
247-
code = message[1]
248-
lines = split(message[3:end], r"[\0\n]+[a-z]!")
249-
submessage = join(lines, '\n')
250-
if code == 'I'
251-
@debug submessage
252-
elseif code == 'W'
253-
@warn submessage
254-
elseif code == 'E'
255-
@error submessage
256-
elseif code == 'F'
257-
error(submessage)
258-
else
259-
@info "Unknown log message, please file an issue.\n$message"
260-
end
238+
code = message[1]
239+
lines = split(message[3:end], r"\n+[a-z]!")
240+
message = join(strip.(lines), '\n')
241+
if code == 'I'
242+
@debug message
243+
elseif code == 'W'
244+
@warn message
245+
elseif code == 'E'
246+
@error message
247+
elseif code == 'F'
248+
error(message)
249+
else
250+
@info "Unknown log message, please file an issue.\n$message"
261251
end
262-
return
263252
end
264253

265254
function __init__()
@@ -273,21 +262,9 @@ function __init__()
273262
# register a log callback
274263
if !Sys.iswindows() && # NVIDIA bug #3321130 &&
275264
!precompiling && (isdebug(:init, CUBLAS) || Base.JLOptions().debug_level >= 2)
276-
log_cond[] = Base.AsyncCondition() do async_cond
277-
blob = ""
278-
while true
279-
message_length = log_cursor[]
280-
blob = unsafe_string(pointer(log_buffer), message_length)
281-
if Threads.atomic_cas!(log_cursor, message_length, UInt(0)) == message_length
282-
break
283-
end
284-
end
285-
_log_message(blob)
286-
return
287-
end
288-
289265
callback = @cfunction(log_message, Nothing, (Cstring,))
290266
cublasSetLoggerCallback(callback)
267+
atexit(flush_log_messages)
291268
end
292269
end
293270

lib/cudnn/src/cuDNN.jl

Lines changed: 5 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -116,10 +116,6 @@ end
116116

117117
## logging
118118

119-
const log_messages = []
120-
const log_lock = ReentrantLock()
121-
const log_cond = Ref{Any}() # root
122-
123119
function log_message(sev, udata, dbg_ptr, ptr)
124120
dbg = unsafe_load(dbg_ptr)
125121

@@ -131,20 +127,11 @@ function log_message(sev, udata, dbg_ptr, ptr)
131127
end
132128
len += 1
133129
end
134-
str = unsafe_string(ptr, len) # XXX: can this yield?
135-
136-
# print asynchronously
137-
Base.@lock log_lock begin
138-
push!(log_messages, (; sev, dbg, str))
139-
end
140-
ccall(:uv_async_send, Cint, (Ptr{Cvoid},), udata)
130+
str = unsafe_string(ptr, len)
141131

142-
return
143-
end
144-
145-
function _log_message(sev, dbg, str)
132+
# split into lines and report
146133
lines = split(str, '\0')
147-
msg = join(lines, '\n')
134+
msg = join(strip.(lines), '\n')
148135
if sev == CUDNN_SEV_INFO
149136
@debug msg
150137
elseif sev == CUDNN_SEV_WARNING
@@ -154,6 +141,7 @@ function _log_message(sev, dbg, str)
154141
elseif sev == CUDNN_SEV_FATAL
155142
error(msg)
156143
end
144+
157145
return
158146
end
159147

@@ -182,18 +170,9 @@ function __init__()
182170

183171
# register a log callback
184172
if !precompiling && (isdebug(:init, cuDNN) || Base.JLOptions().debug_level >= 2)
185-
log_cond[] = Base.AsyncCondition() do async_cond
186-
Base.@lock log_lock begin
187-
while length(log_messages) > 0
188-
message = popfirst!(log_messages)
189-
_log_message(message...)
190-
end
191-
end
192-
end
193-
194173
callback = @cfunction(log_message, Nothing,
195174
(cudnnSeverity_t, Ptr{Cvoid}, Ptr{cudnnDebug_t}, Ptr{UInt8}))
196-
cudnnSetCallback(typemax(UInt32), log_cond[], callback)
175+
cudnnSetCallback(typemax(UInt32), C_NULL, callback)
197176
end
198177

199178
_initialized[] = true

0 commit comments

Comments
 (0)