@@ -210,56 +210,45 @@ end
210210
211211# # logging
212212
213- const MAX_LOG_BUFLEN = UInt (1024 * 1024 )
214- const log_buffer = Vector {UInt8} (undef, MAX_LOG_BUFLEN)
215- const log_cursor = Threads. Atomic {UInt} (0 )
216- const log_cond = Ref {Base.AsyncCondition} () # root
213+ # CUBLAS calls the log callback multiple times for each message, so we need to buffer them
214+ const log_buffer = IOBuffer ()
217215
218216function log_message (ptr)
219- # NOTE: this function may be called from unmanaged threads (by cublasXt),
220- # so we can't even allocate, let alone perform I/O.
221- len = @ccall strlen (ptr:: Cstring ):: Csize_t
222- old_cursor = log_cursor[]
223- new_cursor = old_cursor + len+ 1
224- if new_cursor >= MAX_LOG_BUFLEN
225- # overrun
226- return
227- end
217+ global log_buffer
218+ str = unsafe_string (ptr)
228219
229- @ccall memmove ((pointer (log_buffer)+ old_cursor):: Ptr{Nothing} ,
230- pointer (ptr):: Ptr{Nothing} , (len+ 1 ):: Csize_t ):: Nothing
231- log_cursor[] = new_cursor # the consumer handles CAS'ing this value
220+ # flush if we've started a new log message
221+ if startswith (str, r" [A-Z]!" )
222+ flush_log_messages ()
223+ end
232224
233- # avoid code that depends on the runtime (even the unsafe_convert from ccall does?!)
234- assume (isassigned (log_cond))
235- @ccall uv_async_send (log_cond[]. handle:: Ptr{Nothing} ):: Cint
225+ # append the lines to the buffer
226+ println (log_buffer, str)
236227
237228 return
238229end
239230
240- function _log_message (blob)
231+ function flush_log_messages ()
232+ global log_buffer
233+ message = String (take! (log_buffer))
234+ isempty (message) && return
235+
241236 # the message format isn't documented, but it looks like a message starts with a capital
242237 # and the severity (e.g. `I!`), and subsequent lines start with a lowercase mark (`!i`)
243- #
244- # lines are separated by a \0 if they came in separately, but there may also be multiple
245- # actual lines separated by \n in each message.
246- for message in split (blob, r" [\0\n ]+(?=[A-Z]!)" )
247- code = message[1 ]
248- lines = split (message[3 : end ], r" [\0\n ]+[a-z]!" )
249- submessage = join (lines, ' \n ' )
250- if code == ' I'
251- @debug submessage
252- elseif code == ' W'
253- @warn submessage
254- elseif code == ' E'
255- @error submessage
256- elseif code == ' F'
257- error (submessage)
258- else
259- @info " Unknown log message, please file an issue.\n $message "
260- end
238+ code = message[1 ]
239+ lines = split (message[3 : end ], r" \n +[a-z]!" )
240+ message = join (strip .(lines), ' \n ' )
241+ if code == ' I'
242+ @debug message
243+ elseif code == ' W'
244+ @warn message
245+ elseif code == ' E'
246+ @error message
247+ elseif code == ' F'
248+ error (message)
249+ else
250+ @info " Unknown log message, please file an issue.\n $message "
261251 end
262- return
263252end
264253
265254function __init__ ()
@@ -273,21 +262,9 @@ function __init__()
273262 # register a log callback
274263 if ! Sys. iswindows () && # NVIDIA bug #3321130 &&
275264 ! precompiling && (isdebug (:init , CUBLAS) || Base. JLOptions (). debug_level >= 2 )
276- log_cond[] = Base. AsyncCondition () do async_cond
277- blob = " "
278- while true
279- message_length = log_cursor[]
280- blob = unsafe_string (pointer (log_buffer), message_length)
281- if Threads. atomic_cas! (log_cursor, message_length, UInt (0 )) == message_length
282- break
283- end
284- end
285- _log_message (blob)
286- return
287- end
288-
289265 callback = @cfunction (log_message, Nothing, (Cstring,))
290266 cublasSetLoggerCallback (callback)
267+ atexit (flush_log_messages)
291268 end
292269end
293270
0 commit comments