diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index e7544c1f0..a7d1ff372 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -8,6 +8,14 @@ using Base.Math: throw_complex_domainerror # - add support for vector types # - consider emitting LLVM intrinsics and lowering those in the back-end +### Common Intrinsics +@device_function clamp_fast(x::Float32, minval::Float32, maxval::Float32) = ccall("extern air.fast_clamp.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, minval, maxval) +@device_override Base.clamp(x::Float32, minval::Float32, maxval::Float32) = ccall("extern air.clamp.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, minval, maxval) +@device_override Base.clamp(x::Float16, minval::Float16, maxval::Float16) = ccall("extern air.clamp.f16", llvmcall, Float16, (Float16, Float16, Float16), x, minval, maxval) + +@device_override Base.sign(x::Float32) = ccall("extern air.sign.f32", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.sign(x::Float16) = ccall("extern air.sign.f16", llvmcall, Float16, (Float16,), x) + ### Floating Point Intrinsics ## Metal only supports single and half-precision floating-point types (and their vector counterparts) @@ -17,13 +25,21 @@ using Base.Math: throw_complex_domainerror @device_override Base.abs(x::Float32) = ccall("extern air.fabs.f32", llvmcall, Cfloat, (Cfloat,), x) @device_override Base.abs(x::Float16) = ccall("extern air.fabs.f16", llvmcall, Float16, (Float16,), x) -@device_override FastMath.min_fast(x::Float32) = ccall("extern air.fast_fmin.f32", llvmcall, Cfloat, (Cfloat,), x) -@device_override Base.min(x::Float32) = ccall("extern air.fmin.f32", llvmcall, Cfloat, (Cfloat,), x) -@device_override Base.min(x::Float16) = ccall("extern air.fmin.f16", llvmcall, Float16, (Float16,), x) +@device_override FastMath.min_fast(x::Float32, y::Float32) = ccall("extern air.fast_fmin.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) +@device_override Base.min(x::Float32, y::Float32) = ccall("extern air.fmin.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) +@device_override Base.min(x::Float16, y::Float16) = ccall("extern air.fmin.f16", llvmcall, Float16, (Float16, Float16), x, y) + +@device_override FastMath.min_fast(x::Float32, y::Float32, z::Float32) = ccall("extern air.fast_fmin3.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) +@device_override Base.min(x::Float32, y::Float32, z::Float32) = ccall("extern air.fmin3.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) +@device_override Base.min(x::Float16, y::Float16, z::Float16) = ccall("extern air.fmin3.f16", llvmcall, Float16, (Float16, Float16, Float16), x, y, z) -@device_override FastMath.max_fast(x::Float32) = ccall("extern air.fast_fmax.f32", llvmcall, Cfloat, (Cfloat,), x) -@device_override Base.max(x::Float32) = ccall("extern air.fmax.f32", llvmcall, Cfloat, (Cfloat,), x) -@device_override Base.max(x::Float16) = ccall("extern air.fmax.f16", llvmcall, Float16, (Float16,), x) +@device_override FastMath.max_fast(x::Float32, y::Float32) = ccall("extern air.fast_fmax.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) +@device_override Base.max(x::Float32, y::Float32) = ccall("extern air.fmax.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) +@device_override Base.max(x::Float16, y::Float16) = ccall("extern air.fmax.f16", llvmcall, Float16, (Float16, Float16), x, y) + +@device_override FastMath.max_fast(x::Float32, y::Float32, z::Float32) = ccall("extern air.fast_fmax3.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) +@device_override Base.max(x::Float32, y::Float32, z::Float32) = ccall("extern air.fmax3.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) +@device_override Base.max(x::Float16, y::Float16, z::Float16) = ccall("extern air.fmax3.f16", llvmcall, Float16, (Float16, Float16, Float16), x, y, z) @device_override FastMath.acos_fast(x::Float32) = ccall("extern air.fast_acos.f32", llvmcall, Cfloat, (Cfloat,), x) @device_override Base.acos(x::Float32) = ccall("extern air.acos.f32", llvmcall, Cfloat, (Cfloat,), x) @@ -45,6 +61,10 @@ using Base.Math: throw_complex_domainerror @device_override Base.atan(x::Float32) = ccall("extern air.atan.f32", llvmcall, Cfloat, (Cfloat,), x) @device_override Base.atan(x::Float16) = ccall("extern air.atan.f16", llvmcall, Float16, (Float16,), x) +@device_override FastMath.atan_fast(x::Float32, y::Float32) = ccall("extern air.fast_atan2.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) +@device_override Base.atan(x::Float32, y::Float32) = ccall("extern air.atan2.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) +@device_override Base.atan(x::Float16, y::Float16) = ccall("extern air.atan2.f16", llvmcall, Float16, (Float16, Float16), x, y) + @device_override FastMath.atanh_fast(x::Float32) = ccall("extern air.fast_atanh.f32", llvmcall, Cfloat, (Cfloat,), x) @device_override Base.atanh(x::Float32) = ccall("extern air.atanh.f32", llvmcall, Cfloat, (Cfloat,), x) @device_override Base.atanh(x::Float16) = ccall("extern air.atanh.f16", llvmcall, Float16, (Float16,), x) @@ -240,6 +260,7 @@ end s = ccall("extern air.sincos.f32", llvmcall, Cfloat, (Cfloat, Ptr{Cfloat}), x, c) (s, c[]) end +# XXX: Broken @device_override function Base.sincos(x::Float16) c = Ref{Float16}() s = ccall("extern air.sincos.f16", llvmcall, Float16, (Float16, Ptr{Float16}), x, c) @@ -267,8 +288,8 @@ end @device_override Base.tanh(x::Float16) = ccall("extern air.tanh.f16", llvmcall, Float16, (Float16,), x) @device_function tanpi_fast(x::Float32) = ccall("extern air.fast_tanpi.f32", llvmcall, Cfloat, (Cfloat,), x) -@device_function tanpi(x::Float32) = ccall("extern air.tanpi.f32", llvmcall, Cfloat, (Cfloat,), x) -@device_function tanpi(x::Float16) = ccall("extern air.tanpi.f16", llvmcall, Float16, (Float16,), x) +@device_override Base.tanpi(x::Float32) = ccall("extern air.tanpi.f32", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.tanpi(x::Float16) = ccall("extern air.tanpi.f16", llvmcall, Float16, (Float16,), x) @device_function trunc_fast(x::Float32) = ccall("extern air.fast_trunc.f32", llvmcall, Cfloat, (Cfloat,), x) @device_override Base.trunc(x::Float32) = ccall("extern air.trunc.f32", llvmcall, Cfloat, (Cfloat,), x) @@ -418,7 +439,7 @@ end j = fma(1.442695f0, a, 12582912.0f0) j = j - 12582912.0f0 i = unsafe_trunc(Int32, j) - f = fma(j, -6.93145752f-1, a) # log_2_hi + f = fma(j, -6.93145752f-1, a) # log_2_hi f = fma(j, -1.42860677f-6, f) # log_2_lo # approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2] diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl index c100b2ddd..78f5d2db6 100644 --- a/test/device/intrinsics.jl +++ b/test/device/intrinsics.jl @@ -1,5 +1,6 @@ -using SpecialFunctions using Metal: metal_support +using Random +using SpecialFunctions @testset "arguments" begin @on_device dispatch_quadgroups_per_threadgroup() @@ -103,71 +104,261 @@ end ############################################################################################ +MATH_INTR_FUNCS_1_ARG = [ + # Common functions + # saturate, # T saturate(T x) Clamp between 0.0 and 1.0 + sign, # T sign(T x) returns 0.0 if x is NaN + + # float math + acos, # T acos(T x) + asin, # T asin(T x) + asinh, # T asinh(T x) + atan, # T atan(T x) + atanh, # T atanh(T x) + ceil, # T ceil(T x) + cos, # T cos(T x) + cosh, # T cosh(T x) + cospi, # T cospi(T x) + exp, # T exp(T x) + exp2, # T exp2(T x) + exp10, # T exp10(T x) + abs, #T [f]abs(T x) + floor, # T floor(T x) + Metal.fract, # T fract(T x) + # ilogb, # Ti ilogb(T x) + log, # T log(T x) + log2, # T log2(T x) + log10, # T log10(T x) + # Metal.rint, # T rint(T x) # TODO: Add test. Not sure what the behaviour actually is + round, # T round(T x) + Metal.rsqrt, # T rsqrt(T x) + sin, # T sin(T x) + sinh, # T sinh(T x) + sinpi, # T sinpi(T x) + sqrt, # sqrt(T x) + tan, # T tan(T x) + tanh, # T tanh(T x) + tanpi, # T tanpi(T x) + trunc, # T trunc(T x) +] +Metal.rsqrt(x::Float16) = 1 / sqrt(x) +Metal.rsqrt(x::Float32) = 1 / sqrt(x) +Metal.fract(x::Float16) = mod(x, 1) +Metal.fract(x::Float32) = mod(x, 1) + +MATH_INTR_FUNCS_2_ARG = [ + # Common function + # step, # T step(T edge, T x) Returns 0.0 if x < edge, otherwise it returns 1.0 + + # float math + atan, # T atan2(T x, T y) Compute arc tangent of y over x. + # fdim, # T fdim(T x, T y) + max, # T [f]max(T x, T y) + min, # T [f]min(T x, T y) + # fmod, # T fmod(T x, T y) + # frexp, # T frexp(T x, Ti &exponent) + # ldexp, # T ldexp(T x, Ti k) + # modf, # T modf(T x, T &intval) + # nextafter, # T nextafter(T x, T y) # Metal 3.1+ + # sincos, + hypot, # NOT MSL but tested the same +] + +MATH_INTR_FUNCS_3_ARG = [ + # Common functions + # mix, # T mix(T x, T y, T a) # x+(y-x)*a + # smoothstep, # T smoothstep(T edge0, T edge1, T x) + fma, # T fma(T a, T b, T c) + max, # T max3(T x, T y, T z) + # median3, # T median3(T x, T y, T z) + min, # T min3(T x, T y, T z) +] + @testset "math" begin - a = ones(Float32,1) - a .* Float32(3.14) - bufferA = MtlArray{eltype(a),length(size(a)),Metal.SharedStorage}(a) - vecA = unsafe_wrap(Vector{Float32}, pointer(bufferA), 1) +# 1-arg functions +@testset "$(fun)()::$T" for fun in MATH_INTR_FUNCS_1_ARG, T in (Float32, Float16) + cpuarr = if fun in [log, log2, log10, Metal.rsqrt, sqrt] + rand(T, 4) + else + T[0.0, -0.0, rand(T), -rand(T)] + end + + mtlarr = MtlArray(cpuarr) + + mtlout = fill!(similar(mtlarr), 0) - function intr_test(arr) + function kernel(res, arr) idx = thread_position_in_grid_1d() - arr[idx] = cos(arr[idx]) + res[idx] = fun(arr[idx]) return nothing end - @metal intr_test(bufferA) - synchronize() - @test vecA ≈ cos.(a) + Metal.@sync @metal threads = length(mtlout) kernel(mtlout, mtlarr) + @eval @test Array($mtlout) ≈ $fun.($cpuarr) +end +# 2-arg functions +@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_2_ARG + N = 4 + arr1 = randn(T, N) + arr2 = randn(T, N) + mtlarr1 = MtlArray(arr1) + mtlarr2 = MtlArray(arr2) + + mtlout = fill!(similar(mtlarr1), 0) - function intr_test2(arr) + function kernel(res, x, y) idx = thread_position_in_grid_1d() - arr[idx] = Metal.rsqrt(arr[idx]) + res[idx] = fun(x[idx], y[idx]) return nothing end - @metal intr_test2(bufferA) - synchronize() + Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2) + @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2) +end +# 3-arg functions +@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_3_ARG + N = 4 + arr1 = randn(T, N) + arr2 = randn(T, N) + arr3 = randn(T, N) - bufferB = MtlArray{eltype(a),length(size(a)),Metal.SharedStorage}(a) - vecB = unsafe_wrap(Vector{Float32}, pointer(bufferB), 1) + mtlarr1 = MtlArray(arr1) + mtlarr2 = MtlArray(arr2) + mtlarr3 = MtlArray(arr3) - function intr_test3(arr_sin, arr_cos) + mtlout = fill!(similar(mtlarr1), 0) + + function kernel(res, x, y, z) idx = thread_position_in_grid_1d() - s, c = sincos(arr_cos[idx]) - arr_sin[idx] = s - arr_cos[idx] = c + res[idx] = fun(x[idx], y[idx], z[idx]) return nothing end + Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2, mtlarr3) + @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2, $arr3) +end +end - @metal intr_test3(bufferA, bufferB) - synchronize() - @test vecA ≈ sin.(a) - @test vecB ≈ cos.(a) +@testset "unique math" begin +@testset "$T" for T in (Float32, Float16) + let # acosh + arr = T[0, rand(T, 3)...] .+ T(1) + buffer = MtlArray(arr) + vec = acosh.(buffer) + @test Array(vec) ≈ acosh.(arr) + end - b = collect(LinRange(nextfloat(-1f0), 10f0, 20)) - bufferC = MtlArray(b) - vecC = Array(log1p.(bufferC)) - @test vecC ≈ log1p.(b) + let # sincos + N = 4 + arr = rand(T, N) + bufferA = MtlArray(arr) + bufferB = MtlArray(arr) + function intr_test3(arr_sin, arr_cos) + idx = thread_position_in_grid_1d() + sinres, cosres = sincos(arr_cos[idx]) + arr_sin[idx] = sinres + arr_cos[idx] = cosres + return nothing + end + # Broken with Float16 + if T == Float16 + @test_broken Metal.@sync @metal threads = N intr_test3(bufferA, bufferB) + else + Metal.@sync @metal threads = N intr_test3(bufferA, bufferB) + @test Array(bufferA) ≈ sin.(arr) + @test Array(bufferB) ≈ cos.(arr) + end + end + let # clamp + N = 4 + in = randn(T, N) + minval = fill(T(-0.6), N) + maxval = fill(T(0.6), N) - d = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)) - bufferD = MtlArray(d) - vecD = Array(SpecialFunctions.erf.(bufferD)) - @test vecD ≈ SpecialFunctions.erf.(d) + mtlin = MtlArray(in) + mtlminval = MtlArray(minval) + mtlmaxval = MtlArray(maxval) + mtlout = fill!(similar(mtlin), 0) + + function kernel(res, x, y, z) + idx = thread_position_in_grid_1d() + res[idx] = clamp(x[idx], y[idx], z[idx]) + return nothing + end + Metal.@sync @metal threads = N kernel(mtlout, mtlin, mtlminval, mtlmaxval) + @test Array(mtlout) == clamp.(in, minval, maxval) + end - e = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)) - bufferE = MtlArray(e) - vecE = Array(SpecialFunctions.erfc.(bufferE)) - @test vecE ≈ SpecialFunctions.erfc.(e) + let #pow + N = 4 + arr1 = rand(T, N) + arr2 = rand(T, N) + mtlarr1 = MtlArray(arr1) + mtlarr2 = MtlArray(arr2) - f = collect(LinRange(-1f0, 1f0, 20)) - bufferF = MtlArray(f) - vecF = Array(SpecialFunctions.erfinv.(bufferF)) - @test vecF ≈ SpecialFunctions.erfinv.(f) + mtlout = fill!(similar(mtlarr1), 0) - f = collect(LinRange(nextfloat(-88f0), 88f0, 100)) - bufferF = MtlArray(f) - vecF = Array(expm1.(bufferF)) - @test vecF ≈ expm1.(f) + function kernel(res, x, y) + idx = thread_position_in_grid_1d() + res[idx] = x[idx]^y[idx] + return nothing + end + Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2) + @test Array(mtlout) ≈ arr1 .^ arr2 + end + + let #powr + N = 4 + arr1 = rand(T, N) + arr2 = rand(T, N) + mtlarr1 = MtlArray(arr1) + mtlarr2 = MtlArray(arr2) + + mtlout = fill!(similar(mtlarr1), 0) + + function kernel(res, x, y) + idx = thread_position_in_grid_1d() + res[idx] = Metal.powr(x[idx], y[idx]) + return nothing + end + Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2) + @test Array(mtlout) ≈ arr1 .^ arr2 + end + + let # log1p + arr = collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20)) + buffer = MtlArray(arr) + vec = Array(log1p.(buffer)) + @test vec ≈ log1p.(arr) + end + + let # erf + arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)) + buffer = MtlArray(arr) + vec = Array(SpecialFunctions.erf.(buffer)) + @test vec ≈ SpecialFunctions.erf.(arr) + end + + let # erfc + arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)) + buffer = MtlArray(arr) + vec = Array(SpecialFunctions.erfc.(buffer)) + @test vec ≈ SpecialFunctions.erfc.(arr) + end + + let # erfinv + arr = collect(LinRange(-1.0f0, 1.0f0, 20)) + buffer = MtlArray(arr) + vec = Array(SpecialFunctions.erfinv.(buffer)) + @test vec ≈ SpecialFunctions.erfinv.(arr) + end + + let # expm1 + arr = collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100)) + buffer = MtlArray(arr) + vec = Array(expm1.(buffer)) + @test vec ≈ expm1.(arr) + end +end end ############################################################################################