Skip to content

Commit a0b63f1

Browse files
committed
tag doubles
1 parent 48c48c7 commit a0b63f1

File tree

1 file changed

+34
-4
lines changed

1 file changed

+34
-4
lines changed

src/device/intrinsics/output.jl

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ const __METAL_OS_LOG_TYPE_DEFAULT__ = Int32(0)
77
const __METAL_OS_LOG_TYPE_ERROR__ = Int32(16)
88
const __METAL_OS_LOG_TYPE_FAULT__ = Int32(17)
99

10+
const ALLOW_DOUBLE_META = "allowdouble"
11+
1012
export @mtlprintf
1113

1214
@generated function promote_c_argument(arg)
@@ -18,13 +20,38 @@ export @mtlprintf
1820

1921
if arg == Cchar || arg == Cshort
2022
return :(Cint(arg))
21-
elseif arg == Cfloat
22-
return :(Cdouble(arg))
2323
else
2424
return :(arg)
2525
end
2626
end
2727

28+
@generated function tag_doubles(arg)
29+
@dispose ctx=Context() begin
30+
ret = arg == Cfloat ? Culong : arg
31+
T_arg = convert(LLVMType, arg)
32+
T_ret = convert(LLVMType, ret)
33+
34+
f, ft = create_function(T_ret, [T_arg])
35+
36+
@dispose builder=IRBuilder() begin
37+
entry = BasicBlock(f, "entry")
38+
position!(builder, entry)
39+
40+
p1 = parameters(f)[1]
41+
42+
if arg == Cfloat
43+
res = fpext!(builder, p1, LLVM.DoubleType())
44+
metadata(res)["annotation"] = MDNode([MDString(ALLOW_DOUBLE_META)])
45+
ret!(builder, res)
46+
else
47+
ret!(builder, p1)
48+
end
49+
end
50+
51+
call_function(f, ret, Tuple{arg}, :arg)
52+
end
53+
end
54+
2855
"""
2956
@mtlprintf("%Fmt", args...)
3057
@@ -33,7 +60,7 @@ Print a formatted string in device context on the host standard output.
3360
macro mtlprintf(fmt::String, args...)
3461
fmt_val = Val(Symbol(fmt))
3562

36-
return :(_mtlprintf($fmt_val, $(map(arg -> :(promote_c_argument($arg)), esc.(args))...)))
63+
return :(_mtlprintf($fmt_val, $(map(arg -> :(tag_doubles(promote_c_argument($arg))), esc.(args))...)))
3764
end
3865

3966
@generated function _mtlprintf(::Val{fmt}, argspec...) where {fmt}
@@ -72,7 +99,10 @@ end
7299
args = alloca!(builder, argtypes)
73100
for (i, param) in enumerate(parameters(llvm_f))
74101
p = struct_gep!(builder, argtypes, args, i-1)
75-
store!(builder, param, p)
102+
st = store!(builder, param, p)
103+
if param_types[i] == LLVM.DoubleType()
104+
metadata(st)["annotation"] = MDNode([MDString(ALLOW_DOUBLE_META)])
105+
end
76106
end
77107

78108
dl = datalayout(mod)

0 commit comments

Comments
 (0)