Skip to content

Commit 5b6f208

Browse files
committed
Test more intrinsics and fix min/max
Also clean up the different tests
1 parent 6e8a75d commit 5b6f208

File tree

2 files changed

+200
-50
lines changed

2 files changed

+200
-50
lines changed

src/device/intrinsics/math.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ using Base.Math: throw_complex_domainerror
1717
@device_override Base.abs(x::Float32) = ccall("extern air.fabs.f32", llvmcall, Cfloat, (Cfloat,), x)
1818
@device_override Base.abs(x::Float16) = ccall("extern air.fabs.f16", llvmcall, Float16, (Float16,), x)
1919

20-
@device_override FastMath.min_fast(x::Float32) = ccall("extern air.fast_fmin.f32", llvmcall, Cfloat, (Cfloat,), x)
21-
@device_override Base.min(x::Float32) = ccall("extern air.fmin.f32", llvmcall, Cfloat, (Cfloat,), x)
22-
@device_override Base.min(x::Float16) = ccall("extern air.fmin.f16", llvmcall, Float16, (Float16,), x)
20+
@device_override FastMath.min_fast(x::Float32, y::Float32) = ccall("extern air.fast_fmin.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
21+
@device_override Base.min(x::Float32, y::Float32) = ccall("extern air.fmin.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
22+
@device_override Base.min(x::Float16, y::Float16) = ccall("extern air.fmin.f16", llvmcall, Float16, (Float16, Float16), x, y)
2323

24-
@device_override FastMath.max_fast(x::Float32) = ccall("extern air.fast_fmax.f32", llvmcall, Cfloat, (Cfloat,), x)
25-
@device_override Base.max(x::Float32) = ccall("extern air.fmax.f32", llvmcall, Cfloat, (Cfloat,), x)
26-
@device_override Base.max(x::Float16) = ccall("extern air.fmax.f16", llvmcall, Float16, (Float16,), x)
24+
@device_override FastMath.max_fast(x::Float32, y::Float32) = ccall("extern air.fast_fmax.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
25+
@device_override Base.max(x::Float32, y::Float32) = ccall("extern air.fmax.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
26+
@device_override Base.max(x::Float16, y::Float16) = ccall("extern air.fmax.f16", llvmcall, Float16, (Float16, Float16), x, y)
2727

2828
@device_override FastMath.acos_fast(x::Float32) = ccall("extern air.fast_acos.f32", llvmcall, Cfloat, (Cfloat,), x)
2929
@device_override Base.acos(x::Float32) = ccall("extern air.acos.f32", llvmcall, Cfloat, (Cfloat,), x)
@@ -240,6 +240,7 @@ end
240240
s = ccall("extern air.sincos.f32", llvmcall, Cfloat, (Cfloat, Ptr{Cfloat}), x, c)
241241
(s, c[])
242242
end
243+
# XXX: Broken
243244
@device_override function Base.sincos(x::Float16)
244245
c = Ref{Float16}()
245246
s = ccall("extern air.sincos.f16", llvmcall, Float16, (Float16, Ptr{Float16}), x, c)

test/device/intrinsics.jl

Lines changed: 193 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
using SpecialFunctions
21
using Metal: metal_support
2+
using Random
3+
using SpecialFunctions
34

45
@testset "arguments" begin
56
@on_device dispatch_quadgroups_per_threadgroup()
@@ -103,71 +104,219 @@ end
103104

104105
############################################################################################
105106

107+
MATH_INTR_FUNCS_1_ARG = [
108+
abs,
109+
acos,
110+
# acosh, # not defined for values < 1, tested separately
111+
asin,
112+
asinh,
113+
atan,
114+
atanh,
115+
ceil,
116+
cos,
117+
cosh,
118+
cospi,
119+
exp,
120+
exp2,
121+
exp10,
122+
floor,
123+
Metal.fract,
124+
log,
125+
log2,
126+
log10,
127+
# Metal.rint, # not sure what the behaviour actually is
128+
round,
129+
Metal.rsqrt,
130+
sin,
131+
sinh,
132+
sinpi,
133+
sqrt,
134+
tan,
135+
tanh,
136+
tanpi,
137+
trunc,
138+
]
139+
Metal.rsqrt(x::Float16) = 1 / sqrt(x)
140+
Metal.rsqrt(x::Float32) = 1 / sqrt(x)
141+
Metal.fract(x::Float16) = mod(x, 1)
142+
Metal.fract(x::Float32) = mod(x, 1)
143+
144+
MATH_INTR_FUNCS_2_ARG = [
145+
min,
146+
max,
147+
pow, # :(^),
148+
Metal.powr,
149+
hypot,
150+
]
151+
152+
MATH_INTR_FUNCS_3_ARG = [
153+
fma,
154+
]
155+
106156
@testset "math" begin
107-
a = ones(Float32,1)
108-
a .* Float32(3.14)
109-
bufferA = MtlArray{eltype(a),length(size(a)),Metal.SharedStorage}(a)
110-
vecA = unsafe_wrap(Vector{Float32}, pointer(bufferA), 1)
157+
# 1-arg functions
158+
@testset "$(fun)()::$T" for fun in MATH_INTR_FUNCS_1_ARG, T in (Float32, Float16)
159+
cpuarr = if fun in [log, log2, log10, Metal.rsqrt, sqrt]
160+
rand(T, 4)
161+
else
162+
T[0.0, -0.0, rand(T), -rand(T)]
163+
end
164+
165+
mtlarr = MtlArray(cpuarr)
111166

112-
function intr_test(arr)
167+
mtlout = fill!(similar(mtlarr), 0)
168+
169+
function kernel(res, arr)
113170
idx = thread_position_in_grid_1d()
114-
arr[idx] = cos(arr[idx])
171+
res[idx] = fun(arr[idx])
115172
return nothing
116173
end
117-
@metal intr_test(bufferA)
118-
synchronize()
119-
@test vecA cos.(a)
174+
Metal.@sync @metal threads = length(mtlout) kernel(mtlout, mtlarr)
175+
@eval @test Array($mtlout) $fun.($cpuarr)
176+
end
177+
# 2-arg functions
178+
@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_2_ARG
179+
N = 4
180+
arr1 = randn(T, N)
181+
arr2 = randn(T, N)
182+
mtlarr1 = MtlArray(arr1)
183+
mtlarr2 = MtlArray(arr2)
184+
185+
mtlout = fill!(similar(mtlarr1), 0)
120186

121-
function intr_test2(arr)
187+
function kernel(res, x, y)
122188
idx = thread_position_in_grid_1d()
123-
arr[idx] = Metal.rsqrt(arr[idx])
189+
res[idx] = fun(x[idx], y[idx])
124190
return nothing
125191
end
126-
@metal intr_test2(bufferA)
127-
synchronize()
192+
Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
193+
@eval @test Array($mtlout) $fun.($arr1, $arr2)
194+
end
195+
# 3-arg functions
196+
@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_3_ARG
197+
N = 4
198+
arr1 = randn(T, N)
199+
arr2 = randn(T, N)
200+
arr3 = randn(T, N)
201+
202+
mtlarr1 = MtlArray(arr1)
203+
mtlarr2 = MtlArray(arr2)
204+
mtlarr3 = MtlArray(arr3)
128205

129-
bufferB = MtlArray{eltype(a),length(size(a)),Metal.SharedStorage}(a)
130-
vecB = unsafe_wrap(Vector{Float32}, pointer(bufferB), 1)
206+
mtlout = fill!(similar(mtlarr1), 0)
131207

132-
function intr_test3(arr_sin, arr_cos)
208+
function kernel(res, x, y, z)
133209
idx = thread_position_in_grid_1d()
134-
s, c = sincos(arr_cos[idx])
135-
arr_sin[idx] = s
136-
arr_cos[idx] = c
210+
res[idx] = fun(x[idx], y[idx], z[idx])
137211
return nothing
138212
end
213+
Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2, mtlarr3)
214+
@eval @test Array($mtlout) $fun.($arr1, $arr2, $arr3)
215+
end
216+
end
139217

140-
@metal intr_test3(bufferA, bufferB)
141-
synchronize()
142-
@test vecA sin.(a)
143-
@test vecB cos.(a)
218+
@testset "unique math" begin
219+
@testset "$T" for T in (Float32, Float16)
220+
let # acosh
221+
arr = T[0, rand(T, 3)...] .+ T(1)
222+
buffer = MtlArray(arr)
223+
vec = acosh.(buffer)
224+
@test Array(vec) acosh.(arr)
225+
end
144226

145-
b = collect(LinRange(nextfloat(-1f0), 10f0, 20))
146-
bufferC = MtlArray(b)
147-
vecC = Array(log1p.(bufferC))
148-
@test vecC log1p.(b)
227+
let # sincos
228+
N = 4
229+
arr = rand(T, N)
230+
bufferA = MtlArray(arr)
231+
bufferB = MtlArray(arr)
232+
function intr_test3(arr_sin, arr_cos)
233+
idx = thread_position_in_grid_1d()
234+
sinres, cosres = sincos(arr_cos[idx])
235+
arr_sin[idx] = sinres
236+
arr_cos[idx] = cosres
237+
return nothing
238+
end
239+
# Broken with Float16
240+
if T == Float16
241+
@test_broken Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
242+
else
243+
Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
244+
@test Array(bufferA) sin.(arr)
245+
@test Array(bufferB) cos.(arr)
246+
end
247+
end
149248

249+
let #pow
250+
N = 4
251+
arr1 = rand(T, N)
252+
arr2 = rand(T, N)
253+
mtlarr1 = MtlArray(arr1)
254+
mtlarr2 = MtlArray(arr2)
150255

151-
d = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
152-
bufferD = MtlArray(d)
153-
vecD = Array(SpecialFunctions.erf.(bufferD))
154-
@test vecD SpecialFunctions.erf.(d)
256+
mtlout = fill!(similar(mtlarr1), 0)
155257

258+
function kernel(res, x, y)
259+
idx = thread_position_in_grid_1d()
260+
res[idx] = x[idx]^y[idx]
261+
return nothing
262+
end
263+
Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
264+
@test Array(mtlout) arr1 .^ arr2
265+
end
266+
267+
let #powr
268+
N = 4
269+
arr1 = rand(T, N)
270+
arr2 = rand(T, N)
271+
mtlarr1 = MtlArray(arr1)
272+
mtlarr2 = MtlArray(arr2)
273+
274+
mtlout = fill!(similar(mtlarr1), 0)
275+
276+
function kernel(res, x, y)
277+
idx = thread_position_in_grid_1d()
278+
res[idx] = Metal.powr(x[idx], y[idx])
279+
return nothing
280+
end
281+
Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
282+
@test Array(mtlout) arr1 .^ arr2
283+
end
156284

157-
e = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
158-
bufferE = MtlArray(e)
159-
vecE = Array(SpecialFunctions.erfc.(bufferE))
160-
@test vecE SpecialFunctions.erfc.(e)
285+
let # log1p
286+
arr = collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20))
287+
buffer = MtlArray(arr)
288+
vec = Array(log1p.(buffer))
289+
@test vec log1p.(arr)
290+
end
161291

162-
f = collect(LinRange(-1f0, 1f0, 20))
163-
bufferF = MtlArray(f)
164-
vecF = Array(SpecialFunctions.erfinv.(bufferF))
165-
@test vecF SpecialFunctions.erfinv.(f)
292+
let # erf
293+
arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
294+
buffer = MtlArray(arr)
295+
vec = Array(SpecialFunctions.erf.(buffer))
296+
@test vec SpecialFunctions.erf.(arr)
297+
end
298+
299+
let # erfc
300+
arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
301+
buffer = MtlArray(arr)
302+
vec = Array(SpecialFunctions.erfc.(buffer))
303+
@test vec SpecialFunctions.erfc.(arr)
304+
end
166305

167-
f = collect(LinRange(nextfloat(-88f0), 88f0, 100))
168-
bufferF = MtlArray(f)
169-
vecF = Array(expm1.(bufferF))
170-
@test vecF expm1.(f)
306+
let # erfinv
307+
arr = collect(LinRange(-1.0f0, 1.0f0, 20))
308+
buffer = MtlArray(arr)
309+
vec = Array(SpecialFunctions.erfinv.(buffer))
310+
@test vec SpecialFunctions.erfinv.(arr)
311+
end
312+
313+
let # expm1
314+
arr = collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100))
315+
buffer = MtlArray(arr)
316+
vec = Array(expm1.(buffer))
317+
@test vec expm1.(arr)
318+
end
319+
end
171320
end
172321

173322
############################################################################################

0 commit comments

Comments
 (0)