Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions src/device/intrinsics/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ using Base.Math: throw_complex_domainerror
# - add support for vector types
# - consider emitting LLVM intrinsics and lowering those in the back-end

### Common Intrinsics
@device_function clamp_fast(x::Float32, minval::Float32, maxval::Float32) = ccall("extern air.fast_clamp.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, minval, maxval)
@device_override Base.clamp(x::Float32, minval::Float32, maxval::Float32) = ccall("extern air.clamp.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, minval, maxval)
@device_override Base.clamp(x::Float16, minval::Float16, maxval::Float16) = ccall("extern air.clamp.f16", llvmcall, Float16, (Float16, Float16, Float16), x, minval, maxval)

@device_override Base.sign(x::Float32) = ccall("extern air.sign.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.sign(x::Float16) = ccall("extern air.sign.f16", llvmcall, Float16, (Float16,), x)

### Floating Point Intrinsics

## Metal only supports single and half-precision floating-point types (and their vector counterparts)
Expand All @@ -17,13 +25,21 @@ using Base.Math: throw_complex_domainerror
@device_override Base.abs(x::Float32) = ccall("extern air.fabs.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.abs(x::Float16) = ccall("extern air.fabs.f16", llvmcall, Float16, (Float16,), x)

@device_override FastMath.min_fast(x::Float32) = ccall("extern air.fast_fmin.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.min(x::Float32) = ccall("extern air.fmin.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.min(x::Float16) = ccall("extern air.fmin.f16", llvmcall, Float16, (Float16,), x)
@device_override FastMath.min_fast(x::Float32, y::Float32) = ccall("extern air.fast_fmin.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override Base.min(x::Float32, y::Float32) = ccall("extern air.fmin.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override Base.min(x::Float16, y::Float16) = ccall("extern air.fmin.f16", llvmcall, Float16, (Float16, Float16), x, y)

@device_override FastMath.min_fast(x::Float32, y::Float32, z::Float32) = ccall("extern air.fast_fmin3.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_override Base.min(x::Float32, y::Float32, z::Float32) = ccall("extern air.fmin3.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_override Base.min(x::Float16, y::Float16, z::Float16) = ccall("extern air.fmin3.f16", llvmcall, Float16, (Float16, Float16, Float16), x, y, z)

@device_override FastMath.max_fast(x::Float32) = ccall("extern air.fast_fmax.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.max(x::Float32) = ccall("extern air.fmax.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.max(x::Float16) = ccall("extern air.fmax.f16", llvmcall, Float16, (Float16,), x)
@device_override FastMath.max_fast(x::Float32, y::Float32) = ccall("extern air.fast_fmax.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override Base.max(x::Float32, y::Float32) = ccall("extern air.fmax.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override Base.max(x::Float16, y::Float16) = ccall("extern air.fmax.f16", llvmcall, Float16, (Float16, Float16), x, y)

@device_override FastMath.max_fast(x::Float32, y::Float32, z::Float32) = ccall("extern air.fast_fmax3.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_override Base.max(x::Float32, y::Float32, z::Float32) = ccall("extern air.fmax3.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_override Base.max(x::Float16, y::Float16, z::Float16) = ccall("extern air.fmax3.f16", llvmcall, Float16, (Float16, Float16, Float16), x, y, z)

@device_override FastMath.acos_fast(x::Float32) = ccall("extern air.fast_acos.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.acos(x::Float32) = ccall("extern air.acos.f32", llvmcall, Cfloat, (Cfloat,), x)
Expand All @@ -45,6 +61,10 @@ using Base.Math: throw_complex_domainerror
@device_override Base.atan(x::Float32) = ccall("extern air.atan.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.atan(x::Float16) = ccall("extern air.atan.f16", llvmcall, Float16, (Float16,), x)

@device_override FastMath.atan_fast(x::Float32, y::Float32) = ccall("extern air.fast_atan2.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override Base.atan(x::Float32, y::Float32) = ccall("extern air.atan2.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override Base.atan(x::Float16, y::Float16) = ccall("extern air.atan2.f16", llvmcall, Float16, (Float16, Float16), x, y)

@device_override FastMath.atanh_fast(x::Float32) = ccall("extern air.fast_atanh.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.atanh(x::Float32) = ccall("extern air.atanh.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.atanh(x::Float16) = ccall("extern air.atanh.f16", llvmcall, Float16, (Float16,), x)
Expand Down Expand Up @@ -240,6 +260,7 @@ end
s = ccall("extern air.sincos.f32", llvmcall, Cfloat, (Cfloat, Ptr{Cfloat}), x, c)
(s, c[])
end
# XXX: Broken
@device_override function Base.sincos(x::Float16)
c = Ref{Float16}()
s = ccall("extern air.sincos.f16", llvmcall, Float16, (Float16, Ptr{Float16}), x, c)
Expand Down Expand Up @@ -267,8 +288,8 @@ end
@device_override Base.tanh(x::Float16) = ccall("extern air.tanh.f16", llvmcall, Float16, (Float16,), x)

@device_function tanpi_fast(x::Float32) = ccall("extern air.fast_tanpi.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_function tanpi(x::Float32) = ccall("extern air.tanpi.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_function tanpi(x::Float16) = ccall("extern air.tanpi.f16", llvmcall, Float16, (Float16,), x)
@device_override Base.tanpi(x::Float32) = ccall("extern air.tanpi.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.tanpi(x::Float16) = ccall("extern air.tanpi.f16", llvmcall, Float16, (Float16,), x)

@device_function trunc_fast(x::Float32) = ccall("extern air.fast_trunc.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.trunc(x::Float32) = ccall("extern air.trunc.f32", llvmcall, Cfloat, (Cfloat,), x)
Expand Down Expand Up @@ -418,7 +439,7 @@ end
j = fma(1.442695f0, a, 12582912.0f0)
j = j - 12582912.0f0
i = unsafe_trunc(Int32, j)
f = fma(j, -6.93145752f-1, a) # log_2_hi
f = fma(j, -6.93145752f-1, a) # log_2_hi
f = fma(j, -1.42860677f-6, f) # log_2_lo

# approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]
Expand Down
279 changes: 235 additions & 44 deletions test/device/intrinsics.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using SpecialFunctions
using Metal: metal_support
using Random
using SpecialFunctions

@testset "arguments" begin
@on_device dispatch_quadgroups_per_threadgroup()
Expand Down Expand Up @@ -103,71 +104,261 @@ end

############################################################################################

MATH_INTR_FUNCS_1_ARG = [
# Common functions
# saturate, # T saturate(T x) Clamp between 0.0 and 1.0
sign, # T sign(T x) returns 0.0 if x is NaN

# float math
acos, # T acos(T x)
asin, # T asin(T x)
asinh, # T asinh(T x)
atan, # T atan(T x)
atanh, # T atanh(T x)
ceil, # T ceil(T x)
cos, # T cos(T x)
cosh, # T cosh(T x)
cospi, # T cospi(T x)
exp, # T exp(T x)
exp2, # T exp2(T x)
exp10, # T exp10(T x)
abs, #T [f]abs(T x)
floor, # T floor(T x)
Metal.fract, # T fract(T x)
# ilogb, # Ti ilogb(T x)
log, # T log(T x)
log2, # T log2(T x)
log10, # T log10(T x)
# Metal.rint, # T rint(T x) # TODO: Add test. Not sure what the behaviour actually is
round, # T round(T x)
Metal.rsqrt, # T rsqrt(T x)
sin, # T sin(T x)
sinh, # T sinh(T x)
sinpi, # T sinpi(T x)
sqrt, # sqrt(T x)
tan, # T tan(T x)
tanh, # T tanh(T x)
tanpi, # T tanpi(T x)
trunc, # T trunc(T x)
]
Metal.rsqrt(x::Float16) = 1 / sqrt(x)
Metal.rsqrt(x::Float32) = 1 / sqrt(x)
Metal.fract(x::Float16) = mod(x, 1)
Metal.fract(x::Float32) = mod(x, 1)

MATH_INTR_FUNCS_2_ARG = [
# Common function
# step, # T step(T edge, T x) Returns 0.0 if x < edge, otherwise it returns 1.0

# float math
atan, # T atan2(T x, T y) Compute arc tangent of y over x.
# fdim, # T fdim(T x, T y)
max, # T [f]max(T x, T y)
min, # T [f]min(T x, T y)
# fmod, # T fmod(T x, T y)
# frexp, # T frexp(T x, Ti &exponent)
# ldexp, # T ldexp(T x, Ti k)
# modf, # T modf(T x, T &intval)
# nextafter, # T nextafter(T x, T y) # Metal 3.1+
# sincos,
hypot, # NOT MSL but tested the same
]

MATH_INTR_FUNCS_3_ARG = [
# Common functions
# mix, # T mix(T x, T y, T a) # x+(y-x)*a
# smoothstep, # T smoothstep(T edge0, T edge1, T x)
fma, # T fma(T a, T b, T c)
max, # T max3(T x, T y, T z)
# median3, # T median3(T x, T y, T z)
min, # T min3(T x, T y, T z)
]

@testset "math" begin
a = ones(Float32,1)
a .* Float32(3.14)
bufferA = MtlArray{eltype(a),length(size(a)),Metal.SharedStorage}(a)
vecA = unsafe_wrap(Vector{Float32}, pointer(bufferA), 1)
# 1-arg functions
@testset "$(fun)()::$T" for fun in MATH_INTR_FUNCS_1_ARG, T in (Float32, Float16)
cpuarr = if fun in [log, log2, log10, Metal.rsqrt, sqrt]
rand(T, 4)
else
T[0.0, -0.0, rand(T), -rand(T)]
end

mtlarr = MtlArray(cpuarr)

mtlout = fill!(similar(mtlarr), 0)

function intr_test(arr)
function kernel(res, arr)
idx = thread_position_in_grid_1d()
arr[idx] = cos(arr[idx])
res[idx] = fun(arr[idx])
return nothing
end
@metal intr_test(bufferA)
synchronize()
@test vecA ≈ cos.(a)
Metal.@sync @metal threads = length(mtlout) kernel(mtlout, mtlarr)
@eval @test Array($mtlout) ≈ $fun.($cpuarr)
end
# 2-arg functions
@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_2_ARG
N = 4
arr1 = randn(T, N)
arr2 = randn(T, N)
mtlarr1 = MtlArray(arr1)
mtlarr2 = MtlArray(arr2)

mtlout = fill!(similar(mtlarr1), 0)

function intr_test2(arr)
function kernel(res, x, y)
idx = thread_position_in_grid_1d()
arr[idx] = Metal.rsqrt(arr[idx])
res[idx] = fun(x[idx], y[idx])
return nothing
end
@metal intr_test2(bufferA)
synchronize()
Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
@eval @test Array($mtlout) ≈ $fun.($arr1, $arr2)
end
# 3-arg functions
@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_3_ARG
N = 4
arr1 = randn(T, N)
arr2 = randn(T, N)
arr3 = randn(T, N)

bufferB = MtlArray{eltype(a),length(size(a)),Metal.SharedStorage}(a)
vecB = unsafe_wrap(Vector{Float32}, pointer(bufferB), 1)
mtlarr1 = MtlArray(arr1)
mtlarr2 = MtlArray(arr2)
mtlarr3 = MtlArray(arr3)

function intr_test3(arr_sin, arr_cos)
mtlout = fill!(similar(mtlarr1), 0)

function kernel(res, x, y, z)
idx = thread_position_in_grid_1d()
s, c = sincos(arr_cos[idx])
arr_sin[idx] = s
arr_cos[idx] = c
res[idx] = fun(x[idx], y[idx], z[idx])
return nothing
end
Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2, mtlarr3)
@eval @test Array($mtlout) ≈ $fun.($arr1, $arr2, $arr3)
end
end

@metal intr_test3(bufferA, bufferB)
synchronize()
@test vecA ≈ sin.(a)
@test vecB ≈ cos.(a)
@testset "unique math" begin
@testset "$T" for T in (Float32, Float16)
let # acosh
arr = T[0, rand(T, 3)...] .+ T(1)
buffer = MtlArray(arr)
vec = acosh.(buffer)
@test Array(vec) ≈ acosh.(arr)
end

b = collect(LinRange(nextfloat(-1f0), 10f0, 20))
bufferC = MtlArray(b)
vecC = Array(log1p.(bufferC))
@test vecC ≈ log1p.(b)
let # sincos
N = 4
arr = rand(T, N)
bufferA = MtlArray(arr)
bufferB = MtlArray(arr)
function intr_test3(arr_sin, arr_cos)
idx = thread_position_in_grid_1d()
sinres, cosres = sincos(arr_cos[idx])
arr_sin[idx] = sinres
arr_cos[idx] = cosres
return nothing
end
# Broken with Float16
if T == Float16
@test_broken Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
else
Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
@test Array(bufferA) ≈ sin.(arr)
@test Array(bufferB) ≈ cos.(arr)
end
end

let # clamp
N = 4
in = randn(T, N)
minval = fill(T(-0.6), N)
maxval = fill(T(0.6), N)

d = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
bufferD = MtlArray(d)
vecD = Array(SpecialFunctions.erf.(bufferD))
@test vecD ≈ SpecialFunctions.erf.(d)
mtlin = MtlArray(in)
mtlminval = MtlArray(minval)
mtlmaxval = MtlArray(maxval)

mtlout = fill!(similar(mtlin), 0)

function kernel(res, x, y, z)
idx = thread_position_in_grid_1d()
res[idx] = clamp(x[idx], y[idx], z[idx])
return nothing
end
Metal.@sync @metal threads = N kernel(mtlout, mtlin, mtlminval, mtlmaxval)
@test Array(mtlout) == clamp.(in, minval, maxval)
end

e = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
bufferE = MtlArray(e)
vecE = Array(SpecialFunctions.erfc.(bufferE))
@test vecE ≈ SpecialFunctions.erfc.(e)
let #pow
N = 4
arr1 = rand(T, N)
arr2 = rand(T, N)
mtlarr1 = MtlArray(arr1)
mtlarr2 = MtlArray(arr2)

f = collect(LinRange(-1f0, 1f0, 20))
bufferF = MtlArray(f)
vecF = Array(SpecialFunctions.erfinv.(bufferF))
@test vecF ≈ SpecialFunctions.erfinv.(f)
mtlout = fill!(similar(mtlarr1), 0)

f = collect(LinRange(nextfloat(-88f0), 88f0, 100))
bufferF = MtlArray(f)
vecF = Array(expm1.(bufferF))
@test vecF ≈ expm1.(f)
function kernel(res, x, y)
idx = thread_position_in_grid_1d()
res[idx] = x[idx]^y[idx]
return nothing
end
Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
@test Array(mtlout) ≈ arr1 .^ arr2
end

let #powr
N = 4
arr1 = rand(T, N)
arr2 = rand(T, N)
mtlarr1 = MtlArray(arr1)
mtlarr2 = MtlArray(arr2)

mtlout = fill!(similar(mtlarr1), 0)

function kernel(res, x, y)
idx = thread_position_in_grid_1d()
res[idx] = Metal.powr(x[idx], y[idx])
return nothing
end
Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
@test Array(mtlout) ≈ arr1 .^ arr2
end

let # log1p
arr = collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20))
buffer = MtlArray(arr)
vec = Array(log1p.(buffer))
@test vec ≈ log1p.(arr)
end

let # erf
arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
buffer = MtlArray(arr)
vec = Array(SpecialFunctions.erf.(buffer))
@test vec ≈ SpecialFunctions.erf.(arr)
end

let # erfc
arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
buffer = MtlArray(arr)
vec = Array(SpecialFunctions.erfc.(buffer))
@test vec ≈ SpecialFunctions.erfc.(arr)
end

let # erfinv
arr = collect(LinRange(-1.0f0, 1.0f0, 20))
buffer = MtlArray(arr)
vec = Array(SpecialFunctions.erfinv.(buffer))
@test vec ≈ SpecialFunctions.erfinv.(arr)
end

let # expm1
arr = collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100))
buffer = MtlArray(arr)
vec = Array(expm1.(buffer))
@test vec ≈ expm1.(arr)
end
end
end

############################################################################################
Expand Down