Skip to content

Commit a5b02cf

Browse files
committed
merge os-log
[only benchmarks]
1 parent 5251b6c commit a5b02cf

File tree

7 files changed

+543
-1
lines changed

7 files changed

+543
-1
lines changed

lib/mtl/MTL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ include("events.jl")
3434
include("fences.jl")
3535
include("heap.jl")
3636
include("buffer.jl")
37+
include("log_state.jl")
3738
include("command_queue.jl")
3839
include("command_buf.jl")
3940
include("compute_pipeline.jl")

lib/mtl/command_queue.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
1+
export MTLCommandQueueDescriptor
2+
3+
@objcwrapper immutable=false MTLCommandQueueDescriptor <: NSObject
4+
5+
@objcproperties MTLCommandQueueDescriptor begin
6+
@autoproperty maxCommandBufferCount::NSUInteger
7+
@autoproperty logState::id{MTLLogState} setter=setLogState
8+
end
9+
10+
function MTLCommandQueueDescriptor()
11+
handle = @objc [MTLCommandQueueDescriptor alloc]::id{MTLCommandQueueDescriptor}
12+
obj = MTLCommandQueueDescriptor(handle)
13+
finalizer(release, obj)
14+
@objc [obj::id{MTLCommandQueueDescriptor} init]::id{MTLCommandQueueDescriptor}
15+
return obj
16+
end
17+
18+
119
export MTLCommandQueue
220

321
@objcwrapper immutable=false MTLCommandQueue <: NSObject
@@ -13,3 +31,10 @@ function MTLCommandQueue(dev::MTLDevice)
1331
finalizer(release, obj)
1432
return obj
1533
end
34+
35+
function MTLCommandQueue(dev::MTLDevice, descriptor::MTLCommandQueueDescriptor)
36+
handle = @objc [dev::id{MTLDevice} newCommandQueueWithDescriptor:descriptor::id{MTLCommandQueueDescriptor}]::id{MTLCommandQueue}
37+
obj = MTLCommandQueue(handle)
38+
finalizer(release, obj)
39+
return obj
40+
end

lib/mtl/log_state.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
export MTLLogLevel
2+
3+
@cenum MTLLogLevel::NSInteger begin
4+
MTLLogLevelUndefined = 0
5+
MTLLogLevelDebug = 1
6+
MTLLogLevelInfo = 2
7+
MTLLogLevelNotice = 3
8+
MTLLogLevelError = 4
9+
MTLLogLevelFault = 5
10+
end
11+
12+
export MTLLogStateDescriptor
13+
14+
@objcwrapper immutable=false MTLLogStateDescriptor <: NSObject
15+
16+
@objcproperties MTLLogStateDescriptor begin
17+
@autoproperty level::MTLLogLevel setter=setLevel
18+
@autoproperty bufferSize::NSInteger setter=setBufferSize
19+
end
20+
21+
function MTLLogStateDescriptor()
22+
handle = @objc [MTLLogStateDescriptor alloc]::id{MTLLogStateDescriptor}
23+
obj = MTLLogStateDescriptor(handle)
24+
finalizer(release, obj)
25+
@objc [obj::id{MTLLogStateDescriptor} init]::id{MTLLogStateDescriptor}
26+
return obj
27+
end
28+
29+
30+
export MTLLogState
31+
32+
@objcwrapper MTLLogState <: NSObject
33+
34+
function MTLLogState(dev::MTLDevice, descriptor::MTLLogStateDescriptor)
35+
err = Ref{id{NSError}}(nil)
36+
handle = @objc [dev::id{MTLDevice} newLogStateWithDescriptor:descriptor::id{MTLLogStateDescriptor}
37+
error:err::Ptr{id{NSError}}]::id{MTLLogState}
38+
err[] == nil || throw(NSError(err[]))
39+
MTLLogState(handle)
40+
end

src/Metal.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ include("device/intrinsics/memory.jl")
3636
include("device/intrinsics/simd.jl")
3737
include("device/intrinsics/version.jl")
3838
include("device/intrinsics/atomics.jl")
39+
include("device/intrinsics/output.jl")
3940
include("device/quirks.jl")
4041

4142
# array essentials

src/device/intrinsics/output.jl

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
const MTLLOG_SUBSYSTEM = "com.juliagpu.metal.jl"
2+
const MTLLOG_CATEGRORY = "mtlprintf"
3+
4+
const __METAL_OS_LOG_TYPE_DEBUG__ = Int32(2)
5+
const __METAL_OS_LOG_TYPE_INFO__ = Int32(1)
6+
const __METAL_OS_LOG_TYPE_DEFAULT__ = Int32(0)
7+
const __METAL_OS_LOG_TYPE_ERROR__ = Int32(16)
8+
const __METAL_OS_LOG_TYPE_FAULT__ = Int32(17)
9+
10+
const ALLOW_DOUBLE_META = "allowdouble"
11+
12+
export @mtlprintf
13+
14+
@generated function promote_c_argument(arg)
15+
# > When a function with a variable-length argument list is called, the variable
16+
# > arguments are passed using C's old ``default argument promotions.'' These say that
17+
# > types char and short int are automatically promoted to int, and type float is
18+
# > automatically promoted to double. Therefore, varargs functions will never receive
19+
# > arguments of type char, short int, or float.
20+
21+
if arg == Cchar || arg == Cshort
22+
return :(Cint(arg))
23+
else
24+
return :(arg)
25+
end
26+
end
27+
28+
@generated function tag_doubles(arg)
29+
@dispose ctx=Context() begin
30+
ret = arg == Cfloat ? Cdouble : arg
31+
T_arg = convert(LLVMType, arg)
32+
T_ret = convert(LLVMType, ret)
33+
34+
f, ft = create_function(T_ret, [T_arg])
35+
36+
@dispose builder=IRBuilder() begin
37+
entry = BasicBlock(f, "entry")
38+
position!(builder, entry)
39+
40+
p1 = parameters(f)[1]
41+
42+
if arg == Cfloat
43+
res = fpext!(builder, p1, LLVM.DoubleType())
44+
metadata(res)["ir_check_ignore"] = MDNode([])
45+
ret!(builder, res)
46+
else
47+
ret!(builder, p1)
48+
end
49+
end
50+
51+
call_function(f, ret, Tuple{arg}, :arg)
52+
end
53+
end
54+
55+
56+
"""
57+
@mtlprintf("%Fmt", args...)
58+
59+
Print a formatted string in device context on the host standard output.
60+
"""
61+
macro mtlprintf(fmt::String, args...)
62+
fmt_val = Val(Symbol(fmt))
63+
64+
return :(_mtlprintf($fmt_val, $(map(arg -> :(tag_doubles(promote_c_argument($arg))), esc.(args))...)))
65+
end
66+
67+
@generated function _mtlprintf(::Val{fmt}, argspec...) where {fmt}
68+
@dispose ctx=Context() begin
69+
arg_exprs = [:( argspec[$i] ) for i in 1:length(argspec)]
70+
arg_types = [argspec...]
71+
72+
T_void = LLVM.VoidType()
73+
T_int32 = LLVM.Int32Type()
74+
T_int64 = LLVM.Int64Type()
75+
T_pint8 = LLVM.PointerType(LLVM.Int8Type())
76+
T_pint8a2 = LLVM.PointerType(LLVM.Int8Type(), 2)
77+
78+
# create functions
79+
param_types = LLVMType[convert(LLVMType, typ) for typ in arg_types]
80+
llvm_f, llvm_ft = create_function(T_void, LLVMType[]; vararg=true)
81+
mod = LLVM.parent(llvm_f)
82+
83+
# generate IR
84+
@dispose builder=IRBuilder() begin
85+
entry = BasicBlock(llvm_f, "entry")
86+
position!(builder, entry)
87+
88+
str = globalstring_ptr!(builder, String(fmt), addrspace=2)
89+
90+
# compute argsize
91+
argtypes = LLVM.StructType(param_types)
92+
dl = datalayout(mod)
93+
arg_size = LLVM.ConstantInt(T_int64, sizeof(dl, argtypes))
94+
95+
alloc = alloca!(builder, T_pint8)
96+
buffer = bitcast!(builder, alloc, T_pint8)
97+
alloc_size = LLVM.ConstantInt(T_int64, sizeof(dl, T_pint8))
98+
99+
lifetime_start_fty = LLVM.FunctionType(T_void, [T_int64, T_pint8])
100+
lifetime_start = LLVM.Function(mod, "llvm.lifetime.start.p0i8", lifetime_start_fty)
101+
call!(builder, lifetime_start_fty, lifetime_start, [alloc_size, buffer])
102+
103+
va_start_fty = LLVM.FunctionType(T_void, [T_pint8])
104+
va_start = LLVM.Function(mod, "llvm.va_start", va_start_fty)
105+
call!(builder, va_start_fty, va_start, [buffer])
106+
107+
# invoke @air.os_log and return
108+
subsystem_str = globalstring_ptr!(builder, MTLLOG_SUBSYSTEM, addrspace=2)
109+
category_str = globalstring_ptr!(builder, MTLLOG_CATEGRORY, addrspace=2)
110+
log_type = LLVM.ConstantInt(T_int32, __METAL_OS_LOG_TYPE_DEBUG__)
111+
os_log_fty = LLVM.FunctionType(T_void, [T_pint8a2, T_pint8a2, T_int32, T_pint8a2, T_pint8, T_int64])
112+
os_log = LLVM.Function(mod, "air.os_log", os_log_fty)
113+
114+
arg_ptr = load!(builder, T_pint8, alloc)
115+
116+
call!(builder, os_log_fty, os_log, [subsystem_str, category_str, log_type, str, arg_ptr, arg_size])
117+
118+
va_end_fty = LLVM.FunctionType(T_void, [T_pint8])
119+
va_end = LLVM.Function(mod, "llvm.va_end", va_end_fty)
120+
call!(builder, va_end_fty, va_end, [buffer])
121+
122+
lifetime_end_fty = LLVM.FunctionType(T_void, [T_int64, T_pint8])
123+
lifetime_end = LLVM.Function(mod, "llvm.lifetime.end.p0i8", lifetime_end_fty)
124+
call!(builder, lifetime_end_fty, lifetime_end, [alloc_size, buffer])
125+
126+
ret!(builder)
127+
end
128+
129+
call_function(llvm_f, Nothing, Tuple{arg_types...}, arg_exprs...)
130+
end
131+
end
132+
133+
134+
## print-like functionality
135+
136+
export @mtlprint, @mtlprintln
137+
138+
# simple conversions, defining an expression and the resulting argument type. nothing fancy,
139+
# `@mtlprint` pretty directly maps to `@mtlprintf`; we should just support `write(::IO)`.
140+
const mtlprint_conversions = [
141+
Float32 => (x->:(Float64($x)), Float64),
142+
Ptr{<:Any} => (x->:(reinterpret(Int, $x)), Ptr{Cvoid}),
143+
LLVMPtr{<:Any} => (x->:(reinterpret(Int, $x)), Ptr{Cvoid}),
144+
Bool => (x->:(Int32($x)), Int32),
145+
]
146+
147+
# format specifiers
148+
const mtlprint_specifiers = Dict(
149+
# integers
150+
Int16 => "%hd",
151+
Int32 => "%d",
152+
Int64 => "%ld",
153+
UInt16 => "%hu",
154+
UInt32 => "%u",
155+
UInt64 => "%lu",
156+
157+
# floating-point
158+
Float32 => "%f",
159+
160+
# other
161+
Cchar => "%c",
162+
Ptr{Cvoid} => "%p",
163+
Cstring => "%s",
164+
)
165+
166+
@inline @generated function _mtlprint(parts...)
167+
fmt = ""
168+
args = Expr[]
169+
170+
for i in 1:length(parts)
171+
part = :(parts[$i])
172+
T = parts[i]
173+
174+
# put literals directly in the format string
175+
if T <: Val
176+
fmt *= string(T.parameters[1])
177+
continue
178+
end
179+
180+
# try to convert arguments if they are not supported directly
181+
if !haskey(mtlprint_specifiers, T)
182+
for (Tmatch, rule) in mtlprint_conversions
183+
if T <: Tmatch
184+
part = rule[1](part)
185+
T = rule[2]
186+
break
187+
end
188+
end
189+
end
190+
191+
# render the argument
192+
if haskey(mtlprint_specifiers, T)
193+
fmt *= mtlprint_specifiers[T]
194+
push!(args, part)
195+
elseif T <: Tuple
196+
fmt *= "("
197+
for (j, U) in enumerate(T.parameters)
198+
if haskey(mtlprint_specifiers, U)
199+
fmt *= mtlprint_specifiers[U]
200+
push!(args, :($part[$j]))
201+
if j < length(T.parameters)
202+
fmt *= ", "
203+
elseif length(T.parameters) == 1
204+
fmt *= ","
205+
end
206+
else
207+
@error("@mtlprint does not support values of type $U")
208+
end
209+
end
210+
fmt *= ")"
211+
elseif T <: String
212+
@error("@mtlprint does not support non-literal strings")
213+
elseif T <: Type
214+
fmt *= string(T.parameters[1])
215+
else
216+
@warn("@mtlprint does not support values of type $T")
217+
fmt *= "$(T)(...)"
218+
end
219+
end
220+
221+
quote
222+
@mtlprintf($fmt, $(args...))
223+
end
224+
end
225+
226+
"""
227+
@mtlprint(xs...)
228+
@mtlprintln(xs...)
229+
230+
Print a textual representation of values `xs` to standard output from the GPU. The
231+
functionality builds on `@mtlprintf`, and is intended as a more use friendly alternative of
232+
that API. However, that also means there's only limited support for argument types, handling
233+
16/32/64 signed and unsigned integers, 32 and 64-bit floating point numbers, `Cchar`s and
234+
pointers. For more complex output, use `@mtlprintf` directly.
235+
236+
Limited string interpolation is also possible:
237+
238+
```julia
239+
@mtlprint("Hello, World ", 42, "\\n")
240+
@mtlprint "Hello, World \$(42)\\n"
241+
```
242+
"""
243+
macro mtlprint(parts...)
244+
args = Union{Val,Expr,Symbol}[]
245+
246+
parts = [parts...]
247+
while true
248+
isempty(parts) && break
249+
250+
part = popfirst!(parts)
251+
252+
# handle string interpolation
253+
if isa(part, Expr) && part.head == :string
254+
parts = vcat(part.args, parts)
255+
continue
256+
end
257+
258+
# expose literals to the generator by using Val types
259+
if isbits(part) # literal numbers, etc
260+
push!(args, Val(part))
261+
elseif isa(part, QuoteNode) # literal symbols
262+
push!(args, Val(part.value))
263+
elseif isa(part, String) # literal strings need to be interned
264+
push!(args, Val(Symbol(part)))
265+
else # actual values that will be passed to printf
266+
push!(args, part)
267+
end
268+
end
269+
270+
quote
271+
_mtlprint($(map(esc, args)...))
272+
end
273+
end
274+
275+
@doc (@doc @mtlprint) ->
276+
macro mtlprintln(parts...)
277+
esc(quote
278+
Metal.@mtlprint($(parts...), "\n")
279+
end)
280+
end
281+
282+
export @mtlshow
283+
284+
"""
285+
@mtlshow(ex)
286+
287+
GPU analog of `Base.@show`. It comes with the same type restrictions as [`@mtlprintf`](@ref).
288+
289+
```julia
290+
@mtlshow thread_position_in_grid_1d()
291+
```
292+
"""
293+
macro mtlshow(exs...)
294+
blk = Expr(:block)
295+
for ex in exs
296+
push!(blk.args, :(Metal.@mtlprintln($(sprint(Base.show_unquoted,ex)*" = "),
297+
begin local value = $(esc(ex)) end)))
298+
end
299+
isempty(exs) || push!(blk.args, :value)
300+
blk
301+
end

0 commit comments

Comments
 (0)