|
| 1 | +export @mtlprintf |
| 2 | + |
# Apply C's default argument promotions, as used for variadic calls: `Cchar` and `Cshort`
# values are widened to `Cint`, `Cfloat` values to `Cdouble`, and every other argument is
# passed through untouched. The decision is made on the argument *type* at generation time.
@generated function promote_c_argument(arg)
    # inside the generator `arg` is bound to the argument's type
    arg in (Cchar, Cshort) && return :(Cint(arg))
    arg == Cfloat && return :(Cdouble(arg))
    return :(arg)
end
| 18 | + |
"""
    @mtlprintf("%Fmt", args...)

Print a formatted string in device context on the host standard output.

The format must be a string literal. Each argument is wrapped in a call to
`promote_c_argument` before being handed to the code generator.
"""
macro mtlprintf(fmt::String, args...)
    # lift the literal format string into the type domain so the generator can read it
    fmt_val = Val(Symbol(fmt))

    # user expressions are escaped; the promotion helper resolves in this module
    promoted = [:(promote_c_argument($(esc(arg)))) for arg in args]

    return Expr(:call, :_mtlprintf, fmt_val, promoted...)
end
| 29 | + |
# Code generator behind `@mtlprintf`.
#
# Builds an LLVM function taking one parameter per argument type. The generated IR stores
# the parameters into a stack-allocated struct and calls the `air.os_log` intrinsic with
# the format string, a pointer to that struct and the struct's size in bytes. With no
# arguments, a null buffer and a size of zero are passed instead.
#
# - `::Val{fmt}`: the format string, carried as a `Symbol` type parameter.
# - `argspec...`: inside the generator, the types of the values to print.
@generated function _mtlprintf(::Val{fmt}, argspec...) where {fmt}
    @dispose ctx=Context() begin
        arg_exprs = [:( argspec[$i] ) for i in 1:length(argspec)]
        arg_types = [argspec...]

        T_void = LLVM.VoidType()
        T_int32 = LLVM.Int32Type()
        T_int64 = LLVM.Int64Type()
        T_pint8 = LLVM.PointerType(LLVM.Int8Type())
        T_pint8a2 = LLVM.PointerType(LLVM.Int8Type(), 2)

        # create the wrapper function, one LLVM parameter per Julia argument
        param_types = LLVMType[convert(LLVMType, typ) for typ in arg_types]
        llvm_f, _ = create_function(T_void, param_types)

        mod = LLVM.parent(llvm_f)

        # generate IR
        @dispose builder=IRBuilder() begin
            entry = BasicBlock(llvm_f, "entry")
            position!(builder, entry)

            # the format string is emitted as a global in address space 2
            str = globalstring_ptr!(builder, String(fmt), addrspace=2)
            argsize = 0

            # construct and fill args buffer
            if isempty(argspec)
                buffer = LLVM.PointerNull(T_pint8)
            else
                argtypes = LLVM.StructType("os_log_args")
                elements!(argtypes, param_types)

                # copy every parameter into its field of the stack-allocated struct
                args = alloca!(builder, argtypes)
                for (i, param) in enumerate(parameters(llvm_f))
                    p = struct_gep!(builder, argtypes, args, i-1)
                    store!(builder, param, p)
                end

                dl = datalayout(mod)
                argsize = sizeof(dl, argtypes)
                buffer = bitcast!(builder, args, T_pint8)
            end

            # invoke @air.os_log and return
            subsystem_str = LLVM.PointerNull(T_pint8a2)
            intptr = LLVM.ConstantInt(T_int64, Int64(-1))
            category_str = const_inttoptr(intptr, T_pint8a2)
            # log type passed to the intrinsic; the other values this code previously
            # named (but never used) were default = 0 and debug = 2
            log_type_error = LLVM.ConstantInt(T_int32, Int32(16))
            arg_size = LLVM.ConstantInt(T_int64, Int64(argsize))

            # declare void @air.os_log(i8 addrspace(2)*, i8 addrspace(2)*, i32, i8 addrspace(2)*, i8*, i64) local_unnamed_addr #4
            os_log_fty = LLVM.FunctionType(T_void, [T_pint8a2, T_pint8a2, T_int32, T_pint8a2, T_pint8, T_int64])
            os_log = LLVM.Function(mod, "air.os_log", os_log_fty)
            call!(builder, os_log_fty, os_log, [subsystem_str, category_str, log_type_error, str, buffer, arg_size])
            ret!(builder)
        end

        call_function(llvm_f, Nothing, Tuple{arg_types...}, arg_exprs...)
    end
end
| 93 | + |
| 94 | + |
| 95 | +## print-like functionality |
| 96 | + |
| 97 | +export @mtlprint, @mtlprintln |
| 98 | + |
# Fallback conversions for `@mtlprint`. Each entry maps a matched argument type to a pair
# of (function building the conversion expression, type of the converted argument).
# Deliberately simple: `@mtlprint` maps pretty directly onto `@mtlprintf`; we should just
# support `write(::IO)`.
const mtlprint_conversions = [
    Float32        => (x -> Expr(:call, :Float64, x), Float64),
    Ptr{<:Any}     => (x -> Expr(:call, :reinterpret, :Int, x), Ptr{Cvoid}),
    LLVMPtr{<:Any} => (x -> Expr(:call, :reinterpret, :Int, x), Ptr{Cvoid}),
    Bool           => (x -> Expr(:call, :Int32, x), Int32),
]
| 107 | + |
# printf-style format specifier for every argument type `@mtlprint` can render directly
const mtlprint_specifiers = Dict(
    # signed integers
    Int16      => "%hd",
    Int32      => "%d",
    Int64      => "%ld",

    # unsigned integers
    UInt16     => "%hu",
    UInt32     => "%u",
    UInt64     => "%lu",

    # floating-point
    Float32    => "%f",

    # characters, pointers and C strings
    Cchar      => "%c",
    Ptr{Cvoid} => "%p",
    Cstring    => "%s",
)
| 126 | + |
# Code generator behind `@mtlprint`: turns the argument types into a single format string
# plus the list of value expressions, and emits one `@mtlprintf` call for them.
#
# `Val`-typed parts are literals and are pasted into the format string; other parts are
# looked up in `mtlprint_specifiers`, after trying `mtlprint_conversions` when the type is
# not directly supported.
@inline @generated function _mtlprint(parts...)
    chunks = String[]       # pieces of the format string, in order
    call_args = Expr[]      # expressions passed along to `@mtlprintf`

    for (i, T) in enumerate(parts)
        value = :(parts[$i])

        # literals go straight into the format string
        if T <: Val
            push!(chunks, string(T.parameters[1]))
            continue
        end

        # not supported as-is: apply the first conversion rule that matches, if any
        if !haskey(mtlprint_specifiers, T)
            for (Tmatch, (convert_expr, Tconverted)) in mtlprint_conversions
                T <: Tmatch || continue
                value = convert_expr(value)
                T = Tconverted
                break
            end
        end

        # render the argument
        if haskey(mtlprint_specifiers, T)
            push!(chunks, mtlprint_specifiers[T])
            push!(call_args, value)
        elseif T <: Tuple
            nfields = length(T.parameters)
            push!(chunks, "(")
            for (j, U) in enumerate(T.parameters)
                if !haskey(mtlprint_specifiers, U)
                    @error("@mtlprint does not support values of type $U")
                    continue
                end
                push!(chunks, mtlprint_specifiers[U])
                push!(call_args, :($value[$j]))
                if j < nfields
                    push!(chunks, ", ")
                elseif nfields == 1
                    # single-element tuples print with a trailing comma
                    push!(chunks, ",")
                end
            end
            push!(chunks, ")")
        elseif T <: String
            @error("@mtlprint does not support non-literal strings")
        elseif T <: Type
            # a type passed as a value: print the type itself
            push!(chunks, string(T.parameters[1]))
        else
            @warn("@mtlprint does not support values of type $T")
            push!(chunks, "$(T)(...)")
        end
    end

    fmt = join(chunks)
    return quote
        @mtlprintf($fmt, $(call_args...))
    end
end
| 186 | + |
"""
    @mtlprint(xs...)
    @mtlprintln(xs...)

Print a textual representation of values `xs` to standard output from the GPU. The
functionality builds on `@mtlprintf`, and is intended as a more user-friendly alternative
to that API. However, that also means there's only limited support for argument types,
handling 16/32/64-bit signed and unsigned integers, 32-bit floating point numbers, `Cchar`s
and pointers. For more complex output, use `@mtlprintf` directly.

Limited string interpolation is also possible:

```julia
    @mtlprint("Hello, World ", 42, "\\n")
    @mtlprint "Hello, World \$(42)\\n"
```
"""
macro mtlprint(parts...)
    args = Union{Val,Expr,Symbol}[]

    pending = [parts...]
    while !isempty(pending)
        part = popfirst!(pending)

        # string interpolation: put its pieces back at the front and process them in turn
        if part isa Expr && part.head == :string
            pending = vcat(part.args, pending)
            continue
        end

        # literals are exposed to the generator as `Val` types; anything else is an
        # actual value that will be passed to printf
        lifted = if isbits(part)            # literal numbers, etc
            Val(part)
        elseif part isa QuoteNode           # literal symbols
            Val(part.value)
        elseif part isa String              # literal strings need to be interned
            Val(Symbol(part))
        else
            part
        end
        push!(args, lifted)
    end

    quote
        _mtlprint($(map(esc, args)...))
    end
end
| 235 | + |
# `@mtlprintln` reuses the docstring of `@mtlprint`
@doc (@doc @mtlprint) ->
macro mtlprintln(parts...)
    # forward all parts to `@mtlprint`, followed by a newline literal
    forwarded = quote
        Metal.@mtlprint($(parts...), "\n")
    end
    return esc(forwarded)
end
| 242 | + |
| 243 | +export @mtlshow |
| 244 | + |
"""
    @mtlshow(ex)

GPU analog of `Base.@show`. It comes with the same type restrictions as [`@mtlprintf`](@ref).

```julia
@mtlshow thread_position_in_grid_1d()
```
"""
macro mtlshow(exs...)
    shown = Expr(:block)
    for ex in exs
        # label each value with the source text of the expression that produced it
        label = sprint(Base.show_unquoted, ex) * " = "
        stmt = :(Metal.@mtlprintln($label, begin local value = $(esc(ex)) end))
        push!(shown.args, stmt)
    end
    # like `Base.@show`, evaluate to the last value shown
    if !isempty(exs)
        push!(shown.args, :value)
    end
    return shown
end
0 commit comments