diff --git a/lib/intrinsics/Project.toml b/lib/intrinsics/Project.toml
index 9e854841..a8c9c02f 100644
--- a/lib/intrinsics/Project.toml
+++ b/lib/intrinsics/Project.toml
@@ -9,9 +9,16 @@ GPUToolbox = "096a3bc2-3ced-46d0-87f4-dd12716f4bfc"
 LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 
+[weakdeps]
+SIMD = "fdea26ae-647d-5447-a871-4b548cad5224"
+
+[extensions]
+SPIRVIntrinsicsSIMDExt = "SIMD"
+
 [compat]
 ExprTools = "0.1"
 GPUToolbox = "0.2, 0.3"
 LLVM = "9.1"
+SIMD = "3.6"
 SpecialFunctions = "1.3, 2"
 julia = "1.10"
diff --git a/lib/intrinsics/ext/SPIRVIntrinsicsSIMDExt.jl b/lib/intrinsics/ext/SPIRVIntrinsicsSIMDExt.jl
new file mode 100644
index 00000000..1ecaa610
--- /dev/null
+++ b/lib/intrinsics/ext/SPIRVIntrinsicsSIMDExt.jl
@@ -0,0 +1,134 @@
+module SPIRVIntrinsicsSIMDExt
+
+# Package extension that is activated when SIMD.jl is loaded together with
+# SPIRVIntrinsics (see `[weakdeps]`/`[extensions]` in Project.toml). It forwards math
+# functions on `SIMD.Vec{N,T}` to the matching OpenCL vector built-ins.
+
+using SPIRVIntrinsics
+using SPIRVIntrinsics: @device_override, @device_function, @builtin_ccall, @typed_ccall
+using SIMD
+import SpecialFunctions
+
+# Mangled names of the built-ins referenced from this module. `@builtin_ccall` records
+# them in `__module__.known_intrinsics`, so the extension keeps its own list, which the
+# compiler's `isintrinsic` hook checks next to the one owned by SPIRVIntrinsics.
+const known_intrinsics = String[]
+
+# Generate vectorized math intrinsics
+for N in [2, 3, 4, 8, 16], T in [Float16, Float32, Float64]
+    # `Vec` values are unwrapped via `.data` (typed as `SIMD.LVec`) for the call and the
+    # result is wrapped back into a `Vec`.
+    VT = :(Vec{$N,$T})
+    LVT = :(SIMD.LVec{$N,$T})
+
+    @eval begin
+        # Unary operations
+        @device_override Base.acos(x::$VT) = $VT(@builtin_ccall("acos", $LVT, ($LVT,), x.data))
+        @device_override Base.acosh(x::$VT) = $VT(@builtin_ccall("acosh", $LVT, ($LVT,), x.data))
+        @device_function SPIRVIntrinsics.acospi(x::$VT) = $VT(@builtin_ccall("acospi", $LVT, ($LVT,), x.data))
+
+        @device_override Base.asin(x::$VT) = $VT(@builtin_ccall("asin", $LVT, ($LVT,), x.data))
+        @device_override Base.asinh(x::$VT) = $VT(@builtin_ccall("asinh", $LVT, ($LVT,), x.data))
+        @device_function SPIRVIntrinsics.asinpi(x::$VT) = $VT(@builtin_ccall("asinpi", $LVT, ($LVT,), x.data))
+
+        @device_override Base.atan(x::$VT) = $VT(@builtin_ccall("atan", $LVT, ($LVT,), x.data))
+        @device_override Base.atanh(x::$VT) = $VT(@builtin_ccall("atanh", $LVT, ($LVT,), x.data))
+        @device_function SPIRVIntrinsics.atanpi(x::$VT) = $VT(@builtin_ccall("atanpi", $LVT, ($LVT,), x.data))
+
+        @device_override Base.cbrt(x::$VT) = $VT(@builtin_ccall("cbrt", $LVT, ($LVT,), x.data))
+        @device_override Base.ceil(x::$VT) = $VT(@builtin_ccall("ceil", $LVT, ($LVT,), x.data))
+
+        @device_override Base.cos(x::$VT) = $VT(@builtin_ccall("cos", $LVT, ($LVT,), x.data))
+        @device_override Base.cosh(x::$VT) = $VT(@builtin_ccall("cosh", $LVT, ($LVT,), x.data))
+        @device_override Base.cospi(x::$VT) = $VT(@builtin_ccall("cospi", $LVT, ($LVT,), x.data))
+
+        @device_override SpecialFunctions.erfc(x::$VT) = $VT(@builtin_ccall("erfc", $LVT, ($LVT,), x.data))
+        @device_override SpecialFunctions.erf(x::$VT) = $VT(@builtin_ccall("erf", $LVT, ($LVT,), x.data))
+
+        @device_override Base.exp(x::$VT) = $VT(@builtin_ccall("exp", $LVT, ($LVT,), x.data))
+        @device_override Base.exp2(x::$VT) = $VT(@builtin_ccall("exp2", $LVT, ($LVT,), x.data))
+        @device_override Base.exp10(x::$VT) = $VT(@builtin_ccall("exp10", $LVT, ($LVT,), x.data))
+        @device_override Base.expm1(x::$VT) = $VT(@builtin_ccall("expm1", $LVT, ($LVT,), x.data))
+
+        @device_override Base.abs(x::$VT) = $VT(@builtin_ccall("fabs", $LVT, ($LVT,), x.data))
+        @device_override Base.floor(x::$VT) = $VT(@builtin_ccall("floor", $LVT, ($LVT,), x.data))
+
+        @device_override SpecialFunctions.loggamma(x::$VT) = $VT(@builtin_ccall("lgamma", $LVT, ($LVT,), x.data))
+
+        @device_override Base.log(x::$VT) = $VT(@builtin_ccall("log", $LVT, ($LVT,), x.data))
+        @device_override Base.log2(x::$VT) = $VT(@builtin_ccall("log2", $LVT, ($LVT,), x.data))
+        @device_override Base.log10(x::$VT) = $VT(@builtin_ccall("log10", $LVT, ($LVT,), x.data))
+        @device_override Base.log1p(x::$VT) = $VT(@builtin_ccall("log1p", $LVT, ($LVT,), x.data))
+        @device_function SPIRVIntrinsics.logb(x::$VT) = $VT(@builtin_ccall("logb", $LVT, ($LVT,), x.data))
+
+        @device_function SPIRVIntrinsics.rint(x::$VT) = $VT(@builtin_ccall("rint", $LVT, ($LVT,), x.data))
+        @device_override Base.round(x::$VT) = $VT(@builtin_ccall("round", $LVT, ($LVT,), x.data))
+        @device_function SPIRVIntrinsics.rsqrt(x::$VT) = $VT(@builtin_ccall("rsqrt", $LVT, ($LVT,), x.data))
+
+        @device_override Base.sin(x::$VT) = $VT(@builtin_ccall("sin", $LVT, ($LVT,), x.data))
+        @device_override Base.sinh(x::$VT) = $VT(@builtin_ccall("sinh", $LVT, ($LVT,), x.data))
+        @device_override Base.sinpi(x::$VT) = $VT(@builtin_ccall("sinpi", $LVT, ($LVT,), x.data))
+
+        @device_override Base.sqrt(x::$VT) = $VT(@builtin_ccall("sqrt", $LVT, ($LVT,), x.data))
+
+        @device_override Base.tan(x::$VT) = $VT(@builtin_ccall("tan", $LVT, ($LVT,), x.data))
+        @device_override Base.tanh(x::$VT) = $VT(@builtin_ccall("tanh", $LVT, ($LVT,), x.data))
+        @device_override Base.tanpi(x::$VT) = $VT(@builtin_ccall("tanpi", $LVT, ($LVT,), x.data))
+
+        @device_override SpecialFunctions.gamma(x::$VT) = $VT(@builtin_ccall("tgamma", $LVT, ($LVT,), x.data))
+
+        @device_override Base.trunc(x::$VT) = $VT(@builtin_ccall("trunc", $LVT, ($LVT,), x.data))
+
+        # Binary operations
+        @device_override Base.atan(y::$VT, x::$VT) = $VT(@builtin_ccall("atan2", $LVT, ($LVT, $LVT), y.data, x.data))
+        @device_function SPIRVIntrinsics.atanpi(y::$VT, x::$VT) = $VT(@builtin_ccall("atan2pi", $LVT, ($LVT, $LVT), y.data, x.data))
+
+        @device_override Base.copysign(x::$VT, y::$VT) = $VT(@builtin_ccall("copysign", $LVT, ($LVT, $LVT), x.data, y.data))
+        @device_function SPIRVIntrinsics.dim(x::$VT, y::$VT) = $VT(@builtin_ccall("fdim", $LVT, ($LVT, $LVT), x.data, y.data))
+
+        @device_override Base.hypot(x::$VT, y::$VT) = $VT(@builtin_ccall("hypot", $LVT, ($LVT, $LVT), x.data, y.data))
+
+        @device_override Base.max(x::$VT, y::$VT) = $VT(@builtin_ccall("fmax", $LVT, ($LVT, $LVT), x.data, y.data))
+        @device_override Base.min(x::$VT, y::$VT) = $VT(@builtin_ccall("fmin", $LVT, ($LVT, $LVT), x.data, y.data))
+
+        @device_function SPIRVIntrinsics.maxmag(x::$VT, y::$VT) = $VT(@builtin_ccall("maxmag", $LVT, ($LVT, $LVT), x.data, y.data))
+        @device_function SPIRVIntrinsics.minmag(x::$VT, y::$VT) = $VT(@builtin_ccall("minmag", $LVT, ($LVT, $LVT), x.data, y.data))
+
+        @device_function SPIRVIntrinsics.nextafter(x::$VT, y::$VT) = $VT(@builtin_ccall("nextafter", $LVT, ($LVT, $LVT), x.data, y.data))
+
+        @device_override Base.:(^)(x::$VT, y::$VT) = $VT(@builtin_ccall("pow", $LVT, ($LVT, $LVT), x.data, y.data))
+        @device_function SPIRVIntrinsics.powr(x::$VT, y::$VT) = $VT(@builtin_ccall("powr", $LVT, ($LVT, $LVT), x.data, y.data))
+
+        # NOTE: `Base.rem` is intentionally not mapped to OpenCL `remainder`: that built-in
+        # rounds the quotient to the nearest integer (IEEE remainder), while Julia's `rem`
+        # truncates it, so the two disagree (e.g. for 5 and 3: -1 vs. 2). The scalar `rem`
+        # override is likewise dropped from src/math.jl; `rem` on `Vec` falls back to
+        # whatever SIMD.jl itself defines.
+
+        # Ternary operations
+        @device_override Base.fma(a::$VT, b::$VT, c::$VT) = $VT(@builtin_ccall("fma", $LVT, ($LVT, $LVT, $LVT), a.data, b.data, c.data))
+        @device_function SPIRVIntrinsics.mad(a::$VT, b::$VT, c::$VT) = $VT(@builtin_ccall("mad", $LVT, ($LVT, $LVT, $LVT), a.data, b.data, c.data))
+    end
+
+    # Special operations with Int32 parameters
+    VIntT = :(Vec{$N,Int32})
+    LVIntT = :(SIMD.LVec{$N,Int32})
+
+    @eval begin
+        @device_function SPIRVIntrinsics.ilogb(x::$VT) = $VIntT(@builtin_ccall("ilogb", $LVIntT, ($LVT,), x.data))
+        @device_override Base.ldexp(x::$VT, k::$VIntT) = $VT(@builtin_ccall("ldexp", $LVT, ($LVT, $LVIntT), x.data, k.data))
+        @device_override Base.:(^)(x::$VT, y::$VIntT) = $VT(@builtin_ccall("pown", $LVT, ($LVT, $LVIntT), x.data, y.data))
+        @device_function SPIRVIntrinsics.rootn(x::$VT, y::$VIntT) = $VT(@builtin_ccall("rootn", $LVT, ($LVT, $LVIntT), x.data, y.data))
+    end
+end
+
+# nan functions - take unsigned integer codes and return floats
+for N in [2, 3, 4, 8, 16]
+    @eval begin
+        @device_function SPIRVIntrinsics.nan(nancode::Vec{$N,UInt16}) = Vec{$N,Float16}(@builtin_ccall("nan", SIMD.LVec{$N,Float16}, (SIMD.LVec{$N,UInt16},), nancode.data))
+        @device_function SPIRVIntrinsics.nan(nancode::Vec{$N,UInt32}) = Vec{$N,Float32}(@builtin_ccall("nan", SIMD.LVec{$N,Float32}, (SIMD.LVec{$N,UInt32},), nancode.data))
+        @device_function SPIRVIntrinsics.nan(nancode::Vec{$N,UInt64}) = Vec{$N,Float64}(@builtin_ccall("nan", SIMD.LVec{$N,Float64}, (SIMD.LVec{$N,UInt64},), nancode.data))
+    end
+end
+
+end # module
diff --git a/lib/intrinsics/src/math.jl b/lib/intrinsics/src/math.jl
index d51d603b..39117bb1 100644
--- a/lib/intrinsics/src/math.jl
+++ b/lib/intrinsics/src/math.jl
@@ -1,7 +1,7 @@
 # Math Functions
 
 # TODO: vector types
-const generic_types = [Float32,Float64]
+const generic_types = [Float16, Float32, Float64]
 const generic_types_float = [Float32]
 const generic_types_double = [Float64]
 
@@ -33,7 +33,7 @@ for gentype in generic_types
 
 @device_override Base.cos(x::$gentype) = @builtin_ccall("cos", $gentype, ($gentype,), x)
 @device_override Base.cosh(x::$gentype) = @builtin_ccall("cosh", $gentype, ($gentype,), x)
-@device_function cospi(x::$gentype) = @builtin_ccall("cospi", $gentype, ($gentype,), x)
+@device_override Base.cospi(x::$gentype) = @builtin_ccall("cospi", $gentype, ($gentype,), x)
 
 @device_override SpecialFunctions.erfc(x::$gentype) = @builtin_ccall("erfc", $gentype, ($gentype,), x)
 @device_override SpecialFunctions.erf(x::$gentype) = @builtin_ccall("erf", $gentype, ($gentype,), x)
@@ -59,7 +59,10 @@ for gentype in generic_types
 #@device_override Base.mod(x::$gentype, y::$gentype) = @builtin_ccall("fmod", $gentype, ($gentype, $gentype), x, y)
 # fract(x::$gentype, $gentype *iptr) = @builtin_ccall("fract", $gentype, ($gentype, $gentype *), x, iptr)
 
-@device_override Base.hypot(x::$gentype, y::$gentype) = @builtin_ccall("hypot", $gentype, ($gentype, $gentype), x, y)
+# TODO: remove once https://github.com/pocl/pocl/issues/2034 is addressed
+if $gentype != Float16
+    @device_override Base.hypot(x::$gentype, y::$gentype) = @builtin_ccall("hypot", $gentype, ($gentype, $gentype), x, y)
+end
 
 @device_override SpecialFunctions.loggamma(x::$gentype) = @builtin_ccall("lgamma", $gentype, ($gentype,), x)
 
@@ -81,8 +84,6 @@ for gentype in generic_types
 @device_override Base.:(^)(x::$gentype, y::$gentype) = @builtin_ccall("pow", $gentype, ($gentype, $gentype), x, y)
 @device_function powr(x::$gentype, y::$gentype) = @builtin_ccall("powr", $gentype, ($gentype, $gentype), x, y)
 
-@device_override Base.rem(x::$gentype, y::$gentype) = @builtin_ccall("remainder", $gentype, ($gentype, $gentype), x, y)
-
 @device_function rint(x::$gentype) = @builtin_ccall("rint", $gentype, ($gentype,), x)
 @device_override Base.round(x::$gentype) = @builtin_ccall("round", $gentype, ($gentype,), x)
 
@@ -100,13 +101,13 @@ for gentype in generic_types
     return sinval, cosval[]
 end
 @device_override Base.sinh(x::$gentype) = @builtin_ccall("sinh", $gentype, ($gentype,), x)
-@device_function sinpi(x::$gentype) = @builtin_ccall("sinpi", $gentype, ($gentype,), x)
+@device_override Base.sinpi(x::$gentype) = @builtin_ccall("sinpi", $gentype, ($gentype,), x)
 
 @device_override Base.sqrt(x::$gentype) = @builtin_ccall("sqrt", $gentype, ($gentype,), x)
 
 @device_override Base.tan(x::$gentype) = @builtin_ccall("tan", $gentype, ($gentype,), x)
 @device_override Base.tanh(x::$gentype) = @builtin_ccall("tanh", $gentype, ($gentype,), x)
-@device_function tanpi(x::$gentype) = @builtin_ccall("tanpi", $gentype, ($gentype,), x)
+@device_override Base.tanpi(x::$gentype) = @builtin_ccall("tanpi", $gentype, ($gentype,), x)
 
 @device_override SpecialFunctions.gamma(x::$gentype) = @builtin_ccall("tgamma", $gentype, ($gentype,), x)
 
@@ -151,11 +152,13 @@ end
 # frexp(x::Float64{n}, Int32{n} *exp) = @builtin_ccall("frexp", Float64{n}, (Float64{n}, Int32{n} *), x, exp)
 # frexp(x::Float64, Int32 *exp) = @builtin_ccall("frexp", Float64, (Float64, Int32 *), x, exp)
 
+@device_function ilogb(x::Float16) = @builtin_ccall("ilogb", Int32, (Float16,), x)
 # ilogb(x::Float32{n}) = @builtin_ccall("ilogb", Int32{n}, (Float32{n},), x)
 @device_function ilogb(x::Float32) = @builtin_ccall("ilogb", Int32, (Float32,), x)
 # ilogb(x::Float64{n}) = @builtin_ccall("ilogb", Int32{n}, (Float64{n},), x)
 @device_function ilogb(x::Float64) = @builtin_ccall("ilogb", Int32, (Float64,), x)
 
+@device_override Base.ldexp(x::Float16, k::Int32) = @builtin_ccall("ldexp", Float16, (Float16, Int32), x, k)
 # ldexp(x::Float32{n}, k::Int32{n}) = @builtin_ccall("ldexp", Float32{n}, (Float32{n}, Int32{n}), x, k)
 # ldexp(x::Float32{n}, k::Int32) = @builtin_ccall("ldexp", Float32{n}, (Float32{n}, Int32), x, k)
 @device_override Base.ldexp(x::Float32, k::Int32) = @builtin_ccall("ldexp", Float32, (Float32, Int32), x, k)
@@ -168,11 +171,13 @@ end
 # lgamma_r(x::Float64{n}, Int32{n} *signp) = @builtin_ccall("lgamma_r", Float64{n}, (Float64{n}, Int32{n} *), x, signp)
 # Float64 lgamma_r(x::Float64, Int32 *signp) = @builtin_ccall("lgamma_r", Float64, (Float64, Int32 *), x, signp)
 
+@device_function nan(nancode::UInt16) = @builtin_ccall("nan", Float16, (UInt16,), nancode)
 # nan(nancode::uintn) = @builtin_ccall("nan", Float32{n}, (uintn,), nancode)
 @device_function nan(nancode::UInt32) = @builtin_ccall("nan", Float32, (UInt32,), nancode)
 # nan(nancode::UInt64{n}) = @builtin_ccall("nan", Float64{n}, (UInt64{n},), nancode)
 @device_function nan(nancode::UInt64) = @builtin_ccall("nan", Float64, (UInt64,), nancode)
 
+@device_override Base.:(^)(x::Float16, y::Int32) = @builtin_ccall("pown", Float16, (Float16, Int32), x, y)
 # pown(x::Float32{n}, y::Int32{n}) = @builtin_ccall("pown", Float32{n}, (Float32{n}, Int32{n}), x, y)
 @device_override Base.:(^)(x::Float32, y::Int32) = @builtin_ccall("pown", Float32, (Float32, Int32), x, y)
 # pown(x::Float64{n}, y::Int32{n}) = @builtin_ccall("pown", Float64{n}, (Float64{n}, Int32{n}), x, y)
@@ -183,10 +188,11 @@ end
 # remquo(x::Float64{n}, y::Float64{n}, Int32{n} *quo) = @builtin_ccall("remquo", Float64{n}, (Float64{n}, Float64{n}, Int32{n} *), x, y, quo)
 # remquo(x::Float64, y::Float64, Int32 *quo) = @builtin_ccall("remquo", Float64, (Float64, Float64, Int32 *), x, y, quo)
 
+@device_function rootn(x::Float16, y::Int32) = @builtin_ccall("rootn", Float16, (Float16, Int32), x, y)
 # rootn(x::Float32{n}, y::Int32{n}) = @builtin_ccall("rootn", Float32{n}, (Float32{n}, Int32{n}), x, y)
 @device_function rootn(x::Float32, y::Int32) = @builtin_ccall("rootn", Float32, (Float32, Int32), x, y)
 # rootn(x::Float64{n}, y::Int32{n}) = @builtin_ccall("rootn", Float64{n}, (Float64{n}, Int32{n}), x, y)
-# rootn(x::Float64, y::Int32) = @builtin_ccall("rootn", Float64{n}, (Float64, Int32), x, y)
+@device_function rootn(x::Float64, y::Int32) = @builtin_ccall("rootn", Float64, (Float64, Int32), x, y)
 
 # TODO: half and native
 
diff --git a/lib/intrinsics/src/utils.jl b/lib/intrinsics/src/utils.jl
index e1a5a939..3e81fe74 100644
--- a/lib/intrinsics/src/utils.jl
+++ b/lib/intrinsics/src/utils.jl
@@ -26,6 +26,8 @@ macro builtin_ccall(name, ret, argtypes, args...)
             "c"
         elseif T == UInt8
             "h"
+        elseif T == Float16
+            "Dh"
         elseif T == Float32
             "f"
         elseif T == Float64
@@ -61,17 +63,19 @@ macro builtin_ccall(name, ret, argtypes, args...)
             error("Unknown type $T")
         end
     end
+    # vector arguments are mangled as `Dv<N>_` followed by the mangled element type
+    mangle(::Type{NTuple{N, VecElement{T}}}) where {N, T} = "Dv$(N)_" * mangle(T)
 
     # C++-style mangling; very limited to just support these intrinsics
     # TODO: generalize for use with other intrinsics? do we need to mangle those?
     mangled = "_Z$(length(name))$name"
     for t in argtypes
         # with `@eval @builtin_ccall`, we get actual types in the ast, otherwise symbols
-        t = (isa(t, Symbol) || isa(t, Expr)) ? eval(t) : t
+        t = (isa(t, Symbol) || isa(t, Expr)) ? __module__.eval(t) : t
         mangled *= mangle(t)
     end
 
-    push!(known_intrinsics, mangled)
+    push!(__module__.known_intrinsics, mangled)
     esc(quote
         @typed_ccall($mangled, llvmcall, $ret, ($(argtypes...),), $(args...))
     end)
@@ -85,7 +89,7 @@ Base.Experimental.@MethodTable(method_table)
 
 macro device_override(ex)
     esc(quote
-        Base.Experimental.@overlay(method_table, $ex)
+        Base.Experimental.@overlay($method_table, $ex)
     end)
 end
 
diff --git a/src/compiler/compilation.jl b/src/compiler/compilation.jl
index 1e6fc506..4ed28111 100644
--- a/src/compiler/compilation.jl
+++ b/src/compiler/compilation.jl
@@ -16,6 +16,9 @@ GPUCompiler.isintrinsic(job::OpenCLCompilerJob, fn::String) =
                 Tuple{CompilerJob{SPIRVCompilerTarget}, typeof(fn)},
                 job, fn) ||
     in(fn, known_intrinsics) ||
+    let SPIRVIntrinsicsSIMDExt = Base.get_extension(SPIRVIntrinsics, :SPIRVIntrinsicsSIMDExt)
+        SPIRVIntrinsicsSIMDExt !== nothing && in(fn, SPIRVIntrinsicsSIMDExt.known_intrinsics)
+    end ||
     contains(fn, "__spirv_")
 
 
diff --git a/test/Project.toml b/test/Project.toml
index a44cae6b..57ae7ff9 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -12,6 +12,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+SIMD = "fdea26ae-647d-5447-a871-4b548cad5224"
 SPIRVIntrinsics = "71d1d633-e7e8-4a92-83a1-de8814b09ba8"
 SPIRV_LLVM_Backend_jll = "4376b9bf-cff8-51b6-bb48-39421dff0d0c"
 SPIRV_LLVM_Translator_jll = "4a5d46fc-d8cf-5151-a261-86b458210efb"
diff --git a/test/atomics.jl b/test/atomics.jl
index ce69ca30..f46a535b 100644
--- a/test/atomics.jl
+++ b/test/atomics.jl
@@ -1,4 +1,4 @@
-using SPIRVIntrinsics: @builtin_ccall, @typed_ccall, LLVMPtr
+using SPIRVIntrinsics: @builtin_ccall, @typed_ccall, LLVMPtr, known_intrinsics
 
 @testset "atomics" begin
 
diff --git a/test/intrinsics.jl b/test/intrinsics.jl
index 958b258e..ff6ded40 100644
--- a/test/intrinsics.jl
+++ b/test/intrinsics.jl
@@ -1,3 +1,23 @@
+using SIMD
+
+# Evaluate `f(args...)` inside a kernel and return the result to the host. The result
+# type is taken from `OpenCL.code_typed` and used for the 0-dimensional `CLArray` that
+# the kernel writes into.
+function call_on_device(f, args...)
+    function kernel(res, f, args...)
+        res[] = f(args...)
+        return
+    end
+    T = OpenCL.code_typed(() -> f(args...), ())[][2]
+    res = CLArray{T, 0}(undef)
+    @opencl kernel(res, f, args...)
+    return OpenCL.@allowscalar res[]
+end
+
+const float_types = filter(x -> x <: Base.IEEEFloat, GPUArraysTestSuite.supported_eltypes(CLArray))
+const ispocl = cl.platform().name == "Portable Computing Language"
+const simd_ns = [2, 3, 4, 8, 16]
+
 @testset "intrinsics" begin
 
 @testset "barrier" begin
@@ -49,5 +69,161 @@ end
 
 end
 
+@testset "math" begin
+
+@testset "unary - $T" for T in float_types
+    @testset "$f" for f in [
+        acos, acosh,
+        asin, asinh,
+        atan, atanh,
+        cbrt,
+        ceil,
+        cos, cosh, cospi,
+        exp, exp2, exp10, expm1,
+        abs,
+        floor,
+        log, log2, log10, log1p,
+        round,
+        sin, sinh, sinpi,
+        sqrt,
+        tan, tanh, tanpi,
+        trunc,
+    ]
+        x = rand(T)
+        if f == acosh
+            x += 1
+        end
+        broken = ispocl && T == Float16 && f in [acosh, asinh, atanh, cbrt, cospi, expm1, log1p, sinpi, tanpi]
+        @test call_on_device(f, x) ≈ f(x) broken = broken
+    end
+end
+
+@testset "binary - $T" for T in float_types
+    @testset "$f" for f in [
+        atan,
+        copysign,
+        max,
+        min,
+        hypot,
+        (^),
+    ]
+        x = rand(T)
+        y = rand(T)
+        broken = ispocl && T == Float16 && f == atan
+        @test call_on_device(f, x, y) ≈ f(x, y) broken = broken
+    end
+end
+
+@testset "ternary - $T" for T in float_types
+    @testset "$f" for f in [
+        fma,
+    ]
+        x = rand(T)
+        y = rand(T)
+        z = rand(T)
+        @test call_on_device(f, x, y, z) ≈ f(x, y, z)
+    end
+end
+
+@testset "OpenCL-specific unary - $T" for T in float_types
+    @testset "$f" for f in [
+        OpenCL.acospi,
+        OpenCL.asinpi,
+        OpenCL.atanpi,
+        OpenCL.logb,
+        OpenCL.rint,
+        OpenCL.rsqrt,
+    ]
+        x = rand(T)
+        broken = ispocl && T == Float16 && !(f in [OpenCL.rint, OpenCL.rsqrt])
+        @test call_on_device(f, x) isa Real broken = broken # Just check it doesn't error
+    end
+    broken = ispocl && T == Float16
+    @test call_on_device(OpenCL.ilogb, T(8.0)) isa Int32 broken = broken
+    @test call_on_device(OpenCL.nan, Base.uinttype(T)(0)) isa T
 end
 
+@testset "OpenCL-specific binary - $T" for T in float_types
+    @testset "$f" for f in [
+        OpenCL.atanpi,
+        OpenCL.dim,
+        OpenCL.maxmag,
+        OpenCL.minmag,
+        OpenCL.nextafter,
+        OpenCL.powr,
+    ]
+        x = rand(T)
+        y = rand(T)
+        broken = ispocl && T == Float16 && !(f in [OpenCL.maxmag, OpenCL.minmag])
+        @test call_on_device(f, x, y) isa Real broken = broken # Just check it doesn't error
+    end
+    broken = ispocl && T == Float16
+    @test call_on_device(OpenCL.rootn, T(8.0), Int32(3)) ≈ T(2.0) broken = broken
+end
+
+@testset "OpenCL-specific ternary - $T" for T in float_types
+    x = rand(T)
+    y = rand(T)
+    z = rand(T)
+    @test call_on_device(OpenCL.mad, x, y, z) ≈ x * y + z
+end
+
+@testset "SIMD - $N x $T" for N in simd_ns, T in float_types
+    v = Vec{N, T}(ntuple(_ -> rand(T), N))
+
+    # unary ops: sin, cos, sqrt
+    a = call_on_device(sin, v)
+    @test all(a[i] ≈ sin(v[i]) for i in 1:N)
+
+    b = call_on_device(cos, v)
+    @test all(b[i] ≈ cos(v[i]) for i in 1:N)
+
+    c = call_on_device(sqrt, v)
+    @test all(c[i] ≈ sqrt(v[i]) for i in 1:N)
+
+    # binary ops: max, hypot
+    w = Vec{N, T}(ntuple(_ -> rand(T), N))
+    d = call_on_device(max, v, w)
+    @test all(d[i] == max(v[i], w[i]) for i in 1:N)
+
+    broken = ispocl && T == Float16
+    if !broken
+        h = call_on_device(hypot, v, w)
+        @test all(h[i] ≈ hypot(v[i], w[i]) for i in 1:N)
+    end
+
+    # ternary op: fma
+    x = Vec{N, T}(ntuple(_ -> rand(T), N))
+    e = call_on_device(fma, v, w, x)
+    @test all(e[i] ≈ fma(v[i], w[i], x[i]) for i in 1:N)
+
+    # special cases: ilogb, ldexp, ^ with Int32, rootn
+    v_pos = Vec{N, T}(ntuple(_ -> rand(T) + T(1), N))
+    @test call_on_device(OpenCL.ilogb, v_pos) isa Vec{N, Int32} broken = broken
+
+    k = Vec{N, Int32}(ntuple(_ -> rand(Int32.(-5:5)), N))
+    @test let
+        ldexp_result = call_on_device(ldexp, v_pos, k)
+        all(ldexp_result[i] ≈ ldexp(v_pos[i], k[i]) for i in 1:N)
+    end broken = broken
+
+    base = Vec{N, T}(ntuple(_ -> rand(T) + T(0.5), N))
+    exp_int = Vec{N, Int32}(ntuple(_ -> rand(Int32.(0:3)), N))
+    @test let
+        pow_result = call_on_device(^, base, exp_int)
+        all(pow_result[i] ≈ base[i] ^ exp_int[i] for i in 1:N)
+    end broken = broken
+
+    rootn_base = Vec{N, T}(ntuple(_ -> rand(T) * T(10) + T(1), N))
+    rootn_n = Vec{N, Int32}(ntuple(_ -> rand(Int32.(2:4)), N))
+    @test call_on_device(OpenCL.rootn, rootn_base, rootn_n) isa Vec{N, T} broken = broken
+
+    # special cases: nan
+    nan_code = Vec{N, Base.uinttype(T)}(ntuple(_ -> rand(Base.uinttype(T)), N))
+    nan_result = call_on_device(OpenCL.nan, nan_code)
+    @test all(isnan(nan_result[i]) for i in 1:N)
+end
+
+end
+
+end