@@ -210,56 +210,45 @@ end
210
210
211
211
# # logging
212
212
213
- const MAX_LOG_BUFLEN = UInt (1024 * 1024 )
214
- const log_buffer = Vector {UInt8} (undef, MAX_LOG_BUFLEN)
215
- const log_cursor = Threads. Atomic {UInt} (0 )
216
- const log_cond = Ref {Base.AsyncCondition} () # root
213
+ # CUBLAS calls the log callback multiple times for each message, so we need to buffer them
214
+ const log_buffer = IOBuffer ()
217
215
218
216
function log_message (ptr)
219
- # NOTE: this function may be called from unmanaged threads (by cublasXt),
220
- # so we can't even allocate, let alone perform I/O.
221
- len = @ccall strlen (ptr:: Cstring ):: Csize_t
222
- old_cursor = log_cursor[]
223
- new_cursor = old_cursor + len+ 1
224
- if new_cursor >= MAX_LOG_BUFLEN
225
- # overrun
226
- return
227
- end
217
+ global log_buffer
218
+ str = unsafe_string (ptr)
228
219
229
- @ccall memmove ((pointer (log_buffer)+ old_cursor):: Ptr{Nothing} ,
230
- pointer (ptr):: Ptr{Nothing} , (len+ 1 ):: Csize_t ):: Nothing
231
- log_cursor[] = new_cursor # the consumer handles CAS'ing this value
220
+ # flush if we've started a new log message
221
+ if startswith (str, r" [A-Z]!" )
222
+ flush_log_messages ()
223
+ end
232
224
233
- # avoid code that depends on the runtime (even the unsafe_convert from ccall does?!)
234
- assume (isassigned (log_cond))
235
- @ccall uv_async_send (log_cond[]. handle:: Ptr{Nothing} ):: Cint
225
+ # append the lines to the buffer
226
+ println (log_buffer, str)
236
227
237
228
return
238
229
end
239
230
240
- function _log_message (blob)
231
+ function flush_log_messages ()
232
+ global log_buffer
233
+ message = String (take! (log_buffer))
234
+ isempty (message) && return
235
+
241
236
# the message format isn't documented, but it looks like a message starts with a capital
242
237
# and the severity (e.g. `I!`), and subsequent lines start with a lowercase mark (`!i`)
243
- #
244
- # lines are separated by a \0 if they came in separately, but there may also be multiple
245
- # actual lines separated by \n in each message.
246
- for message in split (blob, r" [\0\n ]+(?=[A-Z]!)" )
247
- code = message[1 ]
248
- lines = split (message[3 : end ], r" [\0\n ]+[a-z]!" )
249
- submessage = join (lines, ' \n ' )
250
- if code == ' I'
251
- @debug submessage
252
- elseif code == ' W'
253
- @warn submessage
254
- elseif code == ' E'
255
- @error submessage
256
- elseif code == ' F'
257
- error (submessage)
258
- else
259
- @info " Unknown log message, please file an issue.\n $message "
260
- end
238
+ code = message[1 ]
239
+ lines = split (message[3 : end ], r" \n +[a-z]!" )
240
+ message = join (strip .(lines), ' \n ' )
241
+ if code == ' I'
242
+ @debug message
243
+ elseif code == ' W'
244
+ @warn message
245
+ elseif code == ' E'
246
+ @error message
247
+ elseif code == ' F'
248
+ error (message)
249
+ else
250
+ @info " Unknown log message, please file an issue.\n $message "
261
251
end
262
- return
263
252
end
264
253
265
254
function __init__ ()
@@ -273,21 +262,9 @@ function __init__()
273
262
# register a log callback
274
263
if ! Sys. iswindows () && # NVIDIA bug #3321130 &&
275
264
! precompiling && (isdebug (:init , CUBLAS) || Base. JLOptions (). debug_level >= 2 )
276
- log_cond[] = Base. AsyncCondition () do async_cond
277
- blob = " "
278
- while true
279
- message_length = log_cursor[]
280
- blob = unsafe_string (pointer (log_buffer), message_length)
281
- if Threads. atomic_cas! (log_cursor, message_length, UInt (0 )) == message_length
282
- break
283
- end
284
- end
285
- _log_message (blob)
286
- return
287
- end
288
-
289
265
callback = @cfunction (log_message, Nothing, (Cstring,))
290
266
cublasSetLoggerCallback (callback)
267
+ atexit (flush_log_messages)
291
268
end
292
269
end
293
270
0 commit comments