Skip to content

Commit 0f52746

Browse files
committed
Implement @mtlprintf using os_log
1 parent df88b29 commit 0f52746

File tree

3 files changed

+286
-1
lines changed

3 files changed

+286
-1
lines changed

src/Metal.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ include("device/intrinsics/memory.jl")
3636
include("device/intrinsics/simd.jl")
3737
include("device/intrinsics/version.jl")
3838
include("device/intrinsics/atomics.jl")
39+
include("device/intrinsics/output.jl")
3940
include("device/quirks.jl")
4041

4142
# array essentials

src/device/intrinsics/output.jl

Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
export @mtlprintf
2+
3+
# Apply C's default argument promotions to a single printf-style argument.
#
# > When a function with a variable-length argument list is called, the variable
# > arguments are passed using C's old ``default argument promotions.'' These say that
# > types char and short int are automatically promoted to int, and type float is
# > automatically promoted to double. Therefore, varargs functions will never receive
# > arguments of type char, short int, or float.
#
# Returns the value converted to `Cint` (for `Cchar`/`Cshort`), to `Cdouble` (for `Cfloat`),
# or the value itself for every other type.
@generated function promote_c_argument(arg)
    # inside the generator `arg` is bound to the argument's *type*, so the choice of
    # conversion is made once per type at code-generation time
    is_narrow_int = arg == Cchar || arg == Cshort
    is_narrow_int && return :(Cint(arg))

    return arg == Cfloat ? :(Cdouble(arg)) : :(arg)
end
18+
19+
"""
    @mtlprintf("%Fmt", args...)

Print a formatted string in device context on the host standard output.

The format must be a string literal. Every argument is wrapped in a call to
`promote_c_argument` before being handed to the generated `_mtlprintf`.
"""
macro mtlprintf(fmt::String, args...)
    # lift the literal format string into the type domain so that the generated
    # `_mtlprintf` can read it while building its code
    fmt_val = Val(Symbol(fmt))

    # user expressions are escaped so they are evaluated in the caller's scope
    promoted = [:(promote_c_argument($(esc(arg)))) for arg in args]

    return :(_mtlprintf($fmt_val, $(promoted...)))
end
29+
30+
# Code generator behind `@mtlprintf`.
#
# `fmt` is the format string, carried as a `Symbol` type parameter of `Val`; `argspec` holds
# the (already promoted) argument types. The generator builds an LLVM function that
#   1. emits the format string as a global in address space 2,
#   2. packs all arguments into a stack-allocated struct (or passes a null buffer when there
#      are no arguments), and
#   3. calls the `air.os_log` intrinsic with that buffer and its size,
# and returns the expression produced by `call_function` for invoking it with the actual
# arguments. The generated function returns `Nothing`.
@generated function _mtlprintf(::Val{fmt}, argspec...) where {fmt}
    @dispose ctx=Context() begin
        arg_exprs = [:( argspec[$i] ) for i in 1:length(argspec)]
        arg_types = [argspec...]

        T_void = LLVM.VoidType()
        T_int32 = LLVM.Int32Type()
        T_int64 = LLVM.Int64Type()
        T_pint8 = LLVM.PointerType(LLVM.Int8Type())
        T_pint8a2 = LLVM.PointerType(LLVM.Int8Type(), 2)

        # create functions
        param_types = LLVMType[convert(LLVMType, typ) for typ in arg_types]
        llvm_f, _ = create_function(T_void, param_types)
        mod = LLVM.parent(llvm_f)

        # generate IR
        @dispose builder=IRBuilder() begin
            entry = BasicBlock(llvm_f, "entry")
            position!(builder, entry)

            str = globalstring_ptr!(builder, String(fmt), addrspace=2)
            argsize = 0

            # construct and fill args buffer
            if isempty(argspec)
                # nothing to pack: pass a null buffer of size 0
                buffer = LLVM.PointerNull(T_pint8)
            else
                argtypes = LLVM.StructType("os_log_args")
                elements!(argtypes, param_types)

                # store every parameter into its own field of the struct
                args = alloca!(builder, argtypes)
                for (i, param) in enumerate(parameters(llvm_f))
                    p = struct_gep!(builder, argtypes, args, i-1)
                    store!(builder, param, p)
                end

                # the buffer size is the struct size under the module's data layout
                dl = datalayout(mod)
                argsize = sizeof(dl, argtypes)
                buffer = bitcast!(builder, args, T_pint8)
            end

            # invoke @air.os_log and return
            subsystem_str = LLVM.PointerNull(T_pint8a2)
            intptr = LLVM.ConstantInt(T_int64, Int64(-1))
            category_str = const_inttoptr(intptr, T_pint8a2)
            # log type 16 is the "error" level; 0 (default) and 2 (debug) are not used here
            log_type_error = LLVM.ConstantInt(T_int32, Int32(16))
            arg_size = LLVM.ConstantInt(T_int64, Int64(argsize))

            # declare void @air.os_log(i8 addrspace(2)*, i8 addrspace(2)*, i32, i8 addrspace(2)*, i8*, i64) local_unnamed_addr #4
            os_log_fty = LLVM.FunctionType(T_void, [T_pint8a2, T_pint8a2, T_int32, T_pint8a2, T_pint8, T_int64])
            os_log = LLVM.Function(mod, "air.os_log", os_log_fty)
            call!(builder, os_log_fty, os_log, [subsystem_str, category_str, log_type_error, str, buffer, arg_size])
            ret!(builder)
        end

        call_function(llvm_f, Nothing, Tuple{arg_types...}, arg_exprs...)
    end
end
93+
94+
95+
## print-like functionality
96+
97+
export @mtlprint, @mtlprintln
98+
99+
# Conversion rules consulted by `_mtlprint` for argument types that have no direct format
# specifier. Each entry pairs a type to match against with a tuple of
# (function rewriting the argument expression, type of the rewritten argument).
# Nothing fancy: `@mtlprint` maps pretty directly onto `@mtlprintf`; we should just support
# `write(::IO)`.
const mtlprint_conversions = [
    Float32        => (x -> :(Float64($x)),          Float64),
    Ptr{<:Any}     => (x -> :(reinterpret(Int, $x)), Ptr{Cvoid}),
    LLVMPtr{<:Any} => (x -> :(reinterpret(Int, $x)), Ptr{Cvoid}),
    Bool           => (x -> :(Int32($x)),            Int32),
]

# Format specifier emitted for each directly printable argument type.
const mtlprint_specifiers = Dict{DataType,String}(
    # integers
    Int16      => "%hd",
    Int32      => "%d",
    Int64      => "%ld",
    UInt16     => "%hu",
    UInt32     => "%u",
    UInt64     => "%lu",

    # floating-point
    Float32    => "%f",

    # other
    Cchar      => "%c",
    Ptr{Cvoid} => "%p",
    Cstring    => "%s",
)
126+
127+
# Code generator behind `@mtlprint`.
#
# Walks the argument types and assembles a single `@mtlprintf` call: literals (passed as
# `Val`s) and type names are pasted into the format string, values with a known specifier
# (directly, or after one of the `mtlprint_conversions`) become printf arguments, and tuples
# are rendered element by element between parentheses. Unsupported values are reported
# through the logging macros while the code is being generated.
@inline @generated function _mtlprint(parts...)
    pieces = String[]   # fragments of the format string, in order
    args = Expr[]       # expressions for the values passed to `@mtlprintf`

    for (i, T) in enumerate(parts)
        part = :(parts[$i])

        # put literals directly in the format string
        if T <: Val
            push!(pieces, string(T.parameters[1]))
            continue
        end

        # try to convert arguments if they are not supported directly;
        # only the first matching rule is applied
        if !haskey(mtlprint_specifiers, T)
            for (Tmatch, rule) in mtlprint_conversions
                T <: Tmatch || continue
                rewrite, Tconverted = rule
                part = rewrite(part)
                T = Tconverted
                break
            end
        end

        # render the argument
        if haskey(mtlprint_specifiers, T)
            push!(pieces, mtlprint_specifiers[T])
            push!(args, part)
        elseif T <: Tuple
            nfields = length(T.parameters)
            push!(pieces, "(")
            for (j, U) in enumerate(T.parameters)
                if !haskey(mtlprint_specifiers, U)
                    @error("@mtlprint does not support values of type $U")
                    continue
                end

                push!(pieces, mtlprint_specifiers[U])
                push!(args, :($part[$j]))
                if j < nfields
                    push!(pieces, ", ")
                elseif nfields == 1
                    # trailing comma, as in Julia's rendering of 1-tuples
                    push!(pieces, ",")
                end
            end
            push!(pieces, ")")
        elseif T <: String
            @error("@mtlprint does not support non-literal strings")
        elseif T <: Type
            push!(pieces, string(T.parameters[1]))
        else
            @warn("@mtlprint does not support values of type $T")
            push!(pieces, "$(T)(...)")
        end
    end

    fmt = join(pieces)
    return :(@mtlprintf($fmt, $(args...)))
end
186+
187+
"""
    @mtlprint(xs...)
    @mtlprintln(xs...)

Print a textual representation of values `xs` to standard output from the GPU. The
functionality builds on `@mtlprintf`, and is intended as a more user-friendly alternative to
that API. However, that also means there's only limited support for argument types:
integers (16/32/64-bit, signed and unsigned), floating-point numbers, `Cchar`s and
pointers. For more complex output, use `@mtlprintf` directly.

Limited string interpolation is also possible:

```julia
    @mtlprint("Hello, World ", 42, "\\n")
    @mtlprint "Hello, World \$(42)\\n"
```
"""
macro mtlprint(parts...)
    args = Union{Val,Expr,Symbol}[]

    pending = [parts...]
    while !isempty(pending)
        part = popfirst!(pending)

        # handle string interpolation: put the pieces of the interpolated string back at
        # the front of the work list so they are processed in order
        if part isa Expr && part.head == :string
            pending = vcat(part.args, pending)
            continue
        end

        # expose literals to the generator by using Val types
        arg = if isbits(part)               # literal numbers, etc
            Val(part)
        elseif part isa QuoteNode           # literal symbols
            Val(part.value)
        elseif part isa String              # literal strings need to be interned
            Val(Symbol(part))
        else                                # actual values that will be passed to printf
            part
        end
        push!(args, arg)
    end

    return quote
        _mtlprint($(map(esc, args)...))
    end
end
235+
236+
# `@mtlprintln` shares the docstring of `@mtlprint`.
@doc (@doc @mtlprint) ->
macro mtlprintln(parts...)
    # forward all parts unchanged to `@mtlprint`, followed by a newline literal
    return esc(:(Metal.@mtlprint($(parts...), "\n")))
end
242+
243+
export @mtlshow
244+
245+
"""
    @mtlshow(exs...)

GPU analog of `Base.@show`. It comes with the same type restrictions as [`@mtlprintf`](@ref).

Each expression is printed through `@mtlprintln` as `expression = value`.

```julia
@mtlshow thread_position_in_grid_1d()
```
"""
macro mtlshow(exs...)
    blk = Expr(:block)
    for ex in exs
        # textual form of the expression, used as the label in front of its value
        label = sprint(Base.show_unquoted,ex)*" = "
        stmt = :(Metal.@mtlprintln($label,
                                   begin local value = $(esc(ex)) end))
        push!(blk.args, stmt)
    end
    # mirror `Base.@show`: finish the block with `value` when anything was shown
    if !isempty(exs)
        push!(blk.args, :value)
    end
    return blk
end

src/state.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,29 @@ function global_queue(dev::MTLDevice)
4747
@autoreleasepool begin
4848
# NOTE: MTLCommandQueue itself is manually reference-counted,
4949
# the release pool is for resources used during its construction.
50-
queue = MTLCommandQueue(dev)
50+
log_state_descriptor = MTLLogStateDescriptor()
51+
log_state = MTLLogState(dev, log_state_descriptor)
52+
53+
# Log callback body: print the text of a log message to standard output.
#
# `subSystem`, `category` and `logLevel` are accepted but not used. `message` may be
# `nothing`: `wrapper` substitutes `nothing` for nil messages, and converting that with
# `String` would throw inside the callback, so such messages are skipped instead.
# Always returns `nothing`.
function log_handler(subSystem, category, logLevel, message)
    message === nothing || println(String(message))
    return nothing
end
57+
# Adapter between the raw callback arguments and `log_handler`: every string-like argument
# that compares equal to `nil` is replaced by `nothing`, anything else is wrapped in an
# `NSString`. `logLevel` is forwarded untouched. Always returns `nothing`.
function wrapper(subSystem, category, logLevel, message)
    subsystem_str = subSystem == nil ? nothing : NSString(subSystem)
    category_str = category == nil ? nothing : NSString(category)
    message_str = message == nil ? nothing : NSString(message)

    log_handler(subsystem_str, category_str, logLevel, message_str)
    return nothing
end
65+
66+
block = @objcblock(wrapper, Nothing, (id{NSString}, id{NSString}, MTLLogLevel, id{NSString}))
67+
68+
@objc [log_state::id{MTLLogState} addLogHandler:block::id{NSBlock}]::Nothing
69+
70+
queue_descriptor = MTLCommandQueueDescriptor()
71+
queue_descriptor.logState = log_state
72+
queue = MTLCommandQueue(dev, queue_descriptor)
5173
queue.label = "global_queue($(current_task()))"
5274
global_queues[queue] = nothing
5375
queue

0 commit comments

Comments
 (0)