Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions lib/intrinsics/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,16 @@ GPUToolbox = "096a3bc2-3ced-46d0-87f4-dd12716f4bfc"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[weakdeps]
SIMD = "fdea26ae-647d-5447-a871-4b548cad5224"

[extensions]
SPIRVIntrinsicsSIMDExt = "SIMD"

[compat]
ExprTools = "0.1"
GPUToolbox = "0.2, 0.3"
LLVM = "9.1"
SIMD = "3.6"
SpecialFunctions = "1.3, 2"
julia = "1.10"
121 changes: 121 additions & 0 deletions lib/intrinsics/ext/SPIRVIntrinsicsSIMDExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
module SPIRVIntrinsicsSIMDExt

using SPIRVIntrinsics
using SPIRVIntrinsics: @device_override, @device_function, @builtin_ccall, @typed_ccall
using SIMD
import SpecialFunctions

const known_intrinsics = String[]

# Generate vectorized math intrinsics
for N in [2, 3, 4, 8, 16], T in [Float16, Float32, Float64]
VT = :(Vec{$N,$T})
LVT = :(SIMD.LVec{$N,$T})

@eval begin
# Unary operations
@device_override Base.acos(x::$VT) = $VT(@builtin_ccall("acos", $LVT, ($LVT,), x.data))
@device_override Base.acosh(x::$VT) = $VT(@builtin_ccall("acosh", $LVT, ($LVT,), x.data))
@device_function SPIRVIntrinsics.acospi(x::$VT) = $VT(@builtin_ccall("acospi", $LVT, ($LVT,), x.data))

@device_override Base.asin(x::$VT) = $VT(@builtin_ccall("asin", $LVT, ($LVT,), x.data))
@device_override Base.asinh(x::$VT) = $VT(@builtin_ccall("asinh", $LVT, ($LVT,), x.data))
@device_function SPIRVIntrinsics.asinpi(x::$VT) = $VT(@builtin_ccall("asinpi", $LVT, ($LVT,), x.data))

@device_override Base.atan(x::$VT) = $VT(@builtin_ccall("atan", $LVT, ($LVT,), x.data))
@device_override Base.atanh(x::$VT) = $VT(@builtin_ccall("atanh", $LVT, ($LVT,), x.data))
@device_function SPIRVIntrinsics.atanpi(x::$VT) = $VT(@builtin_ccall("atanpi", $LVT, ($LVT,), x.data))

@device_override Base.cbrt(x::$VT) = $VT(@builtin_ccall("cbrt", $LVT, ($LVT,), x.data))
@device_override Base.ceil(x::$VT) = $VT(@builtin_ccall("ceil", $LVT, ($LVT,), x.data))

@device_override Base.cos(x::$VT) = $VT(@builtin_ccall("cos", $LVT, ($LVT,), x.data))
@device_override Base.cosh(x::$VT) = $VT(@builtin_ccall("cosh", $LVT, ($LVT,), x.data))
@device_override Base.cospi(x::$VT) = $VT(@builtin_ccall("cospi", $LVT, ($LVT,), x.data))

@device_override SpecialFunctions.erfc(x::$VT) = $VT(@builtin_ccall("erfc", $LVT, ($LVT,), x.data))
@device_override SpecialFunctions.erf(x::$VT) = $VT(@builtin_ccall("erf", $LVT, ($LVT,), x.data))

@device_override Base.exp(x::$VT) = $VT(@builtin_ccall("exp", $LVT, ($LVT,), x.data))
@device_override Base.exp2(x::$VT) = $VT(@builtin_ccall("exp2", $LVT, ($LVT,), x.data))
@device_override Base.exp10(x::$VT) = $VT(@builtin_ccall("exp10", $LVT, ($LVT,), x.data))
@device_override Base.expm1(x::$VT) = $VT(@builtin_ccall("expm1", $LVT, ($LVT,), x.data))

@device_override Base.abs(x::$VT) = $VT(@builtin_ccall("fabs", $LVT, ($LVT,), x.data))
@device_override Base.floor(x::$VT) = $VT(@builtin_ccall("floor", $LVT, ($LVT,), x.data))

@device_override SpecialFunctions.loggamma(x::$VT) = $VT(@builtin_ccall("lgamma", $LVT, ($LVT,), x.data))

@device_override Base.log(x::$VT) = $VT(@builtin_ccall("log", $LVT, ($LVT,), x.data))
@device_override Base.log2(x::$VT) = $VT(@builtin_ccall("log2", $LVT, ($LVT,), x.data))
@device_override Base.log10(x::$VT) = $VT(@builtin_ccall("log10", $LVT, ($LVT,), x.data))
@device_override Base.log1p(x::$VT) = $VT(@builtin_ccall("log1p", $LVT, ($LVT,), x.data))
@device_function SPIRVIntrinsics.logb(x::$VT) = $VT(@builtin_ccall("logb", $LVT, ($LVT,), x.data))

@device_function SPIRVIntrinsics.rint(x::$VT) = $VT(@builtin_ccall("rint", $LVT, ($LVT,), x.data))
@device_override Base.round(x::$VT) = $VT(@builtin_ccall("round", $LVT, ($LVT,), x.data))
@device_function SPIRVIntrinsics.rsqrt(x::$VT) = $VT(@builtin_ccall("rsqrt", $LVT, ($LVT,), x.data))

@device_override Base.sin(x::$VT) = $VT(@builtin_ccall("sin", $LVT, ($LVT,), x.data))
@device_override Base.sinh(x::$VT) = $VT(@builtin_ccall("sinh", $LVT, ($LVT,), x.data))
@device_override Base.sinpi(x::$VT) = $VT(@builtin_ccall("sinpi", $LVT, ($LVT,), x.data))

@device_override Base.sqrt(x::$VT) = $VT(@builtin_ccall("sqrt", $LVT, ($LVT,), x.data))

@device_override Base.tan(x::$VT) = $VT(@builtin_ccall("tan", $LVT, ($LVT,), x.data))
@device_override Base.tanh(x::$VT) = $VT(@builtin_ccall("tanh", $LVT, ($LVT,), x.data))
@device_override Base.tanpi(x::$VT) = $VT(@builtin_ccall("tanpi", $LVT, ($LVT,), x.data))

@device_override SpecialFunctions.gamma(x::$VT) = $VT(@builtin_ccall("tgamma", $LVT, ($LVT,), x.data))

@device_override Base.trunc(x::$VT) = $VT(@builtin_ccall("trunc", $LVT, ($LVT,), x.data))

# Binary operations
@device_override Base.atan(y::$VT, x::$VT) = $VT(@builtin_ccall("atan2", $LVT, ($LVT, $LVT), y.data, x.data))
@device_function SPIRVIntrinsics.atanpi(y::$VT, x::$VT) = $VT(@builtin_ccall("atan2pi", $LVT, ($LVT, $LVT), y.data, x.data))

@device_override Base.copysign(x::$VT, y::$VT) = $VT(@builtin_ccall("copysign", $LVT, ($LVT, $LVT), x.data, y.data))
@device_function SPIRVIntrinsics.dim(x::$VT, y::$VT) = $VT(@builtin_ccall("fdim", $LVT, ($LVT, $LVT), x.data, y.data))

@device_override Base.hypot(x::$VT, y::$VT) = $VT(@builtin_ccall("hypot", $LVT, ($LVT, $LVT), x.data, y.data))

@device_override Base.max(x::$VT, y::$VT) = $VT(@builtin_ccall("fmax", $LVT, ($LVT, $LVT), x.data, y.data))
@device_override Base.min(x::$VT, y::$VT) = $VT(@builtin_ccall("fmin", $LVT, ($LVT, $LVT), x.data, y.data))

@device_function SPIRVIntrinsics.maxmag(x::$VT, y::$VT) = $VT(@builtin_ccall("maxmag", $LVT, ($LVT, $LVT), x.data, y.data))
@device_function SPIRVIntrinsics.minmag(x::$VT, y::$VT) = $VT(@builtin_ccall("minmag", $LVT, ($LVT, $LVT), x.data, y.data))

@device_function SPIRVIntrinsics.nextafter(x::$VT, y::$VT) = $VT(@builtin_ccall("nextafter", $LVT, ($LVT, $LVT), x.data, y.data))

@device_override Base.:(^)(x::$VT, y::$VT) = $VT(@builtin_ccall("pow", $LVT, ($LVT, $LVT), x.data, y.data))
@device_function SPIRVIntrinsics.powr(x::$VT, y::$VT) = $VT(@builtin_ccall("powr", $LVT, ($LVT, $LVT), x.data, y.data))

@device_override Base.rem(x::$VT, y::$VT) = $VT(@builtin_ccall("remainder", $LVT, ($LVT, $LVT), x.data, y.data))

# Ternary operations
@device_override Base.fma(a::$VT, b::$VT, c::$VT) = $VT(@builtin_ccall("fma", $LVT, ($LVT, $LVT, $LVT), a.data, b.data, c.data))
@device_function SPIRVIntrinsics.mad(a::$VT, b::$VT, c::$VT) = $VT(@builtin_ccall("mad", $LVT, ($LVT, $LVT, $LVT), a.data, b.data, c.data))
end

# Special operations with Int32 parameters
VIntT = :(Vec{$N,Int32})
LVIntT = :(SIMD.LVec{$N,Int32})

@eval begin
@device_function SPIRVIntrinsics.ilogb(x::$VT) = $VIntT(@builtin_ccall("ilogb", $LVIntT, ($LVT,), x.data))
@device_override Base.ldexp(x::$VT, k::$VIntT) = $VT(@builtin_ccall("ldexp", $LVT, ($LVT, $LVIntT), x.data, k.data))
@device_override Base.:(^)(x::$VT, y::$VIntT) = $VT(@builtin_ccall("pown", $LVT, ($LVT, $LVIntT), x.data, y.data))
@device_function SPIRVIntrinsics.rootn(x::$VT, y::$VIntT) = $VT(@builtin_ccall("rootn", $LVT, ($LVT, $LVIntT), x.data, y.data))
end
end

# nan functions - take unsigned integer codes and return floats
for N in [2, 3, 4, 8, 16]
@eval begin
@device_function SPIRVIntrinsics.nan(nancode::Vec{$N,UInt16}) = Vec{$N,Float16}(@builtin_ccall("nan", SIMD.LVec{$N,Float16}, (SIMD.LVec{$N,UInt16},), nancode.data))
@device_function SPIRVIntrinsics.nan(nancode::Vec{$N,UInt32}) = Vec{$N,Float32}(@builtin_ccall("nan", SIMD.LVec{$N,Float32}, (SIMD.LVec{$N,UInt32},), nancode.data))
@device_function SPIRVIntrinsics.nan(nancode::Vec{$N,UInt64}) = Vec{$N,Float64}(@builtin_ccall("nan", SIMD.LVec{$N,Float64}, (SIMD.LVec{$N,UInt64},), nancode.data))
end
end

end # module
22 changes: 14 additions & 8 deletions lib/intrinsics/src/math.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Math Functions

# TODO: vector types
const generic_types = [Float32,Float64]
const generic_types = [Float16, Float32, Float64]
const generic_types_float = [Float32]
const generic_types_double = [Float64]

Expand Down Expand Up @@ -33,7 +33,7 @@ for gentype in generic_types

@device_override Base.cos(x::$gentype) = @builtin_ccall("cos", $gentype, ($gentype,), x)
@device_override Base.cosh(x::$gentype) = @builtin_ccall("cosh", $gentype, ($gentype,), x)
@device_function cospi(x::$gentype) = @builtin_ccall("cospi", $gentype, ($gentype,), x)
@device_override Base.cospi(x::$gentype) = @builtin_ccall("cospi", $gentype, ($gentype,), x)

@device_override SpecialFunctions.erfc(x::$gentype) = @builtin_ccall("erfc", $gentype, ($gentype,), x)
@device_override SpecialFunctions.erf(x::$gentype) = @builtin_ccall("erf", $gentype, ($gentype,), x)
Expand All @@ -59,7 +59,10 @@ for gentype in generic_types
#@device_override Base.mod(x::$gentype, y::$gentype) = @builtin_ccall("fmod", $gentype, ($gentype, $gentype), x, y)
# fract(x::$gentype, $gentype *iptr) = @builtin_ccall("fract", $gentype, ($gentype, $gentype *), x, iptr)

@device_override Base.hypot(x::$gentype, y::$gentype) = @builtin_ccall("hypot", $gentype, ($gentype, $gentype), x, y)
# TODO: remove once https://github.com/pocl/pocl/issues/2034 is addressed
if $gentype != Float16
@device_override Base.hypot(x::$gentype, y::$gentype) = @builtin_ccall("hypot", $gentype, ($gentype, $gentype), x, y)
end

@device_override SpecialFunctions.loggamma(x::$gentype) = @builtin_ccall("lgamma", $gentype, ($gentype,), x)

Expand All @@ -81,8 +84,6 @@ for gentype in generic_types
@device_override Base.:(^)(x::$gentype, y::$gentype) = @builtin_ccall("pow", $gentype, ($gentype, $gentype), x, y)
@device_function powr(x::$gentype, y::$gentype) = @builtin_ccall("powr", $gentype, ($gentype, $gentype), x, y)

@device_override Base.rem(x::$gentype, y::$gentype) = @builtin_ccall("remainder", $gentype, ($gentype, $gentype), x, y)

@device_function rint(x::$gentype) = @builtin_ccall("rint", $gentype, ($gentype,), x)

@device_override Base.round(x::$gentype) = @builtin_ccall("round", $gentype, ($gentype,), x)
Expand All @@ -100,13 +101,13 @@ for gentype in generic_types
return sinval, cosval[]
end
@device_override Base.sinh(x::$gentype) = @builtin_ccall("sinh", $gentype, ($gentype,), x)
@device_function sinpi(x::$gentype) = @builtin_ccall("sinpi", $gentype, ($gentype,), x)
@device_override Base.sinpi(x::$gentype) = @builtin_ccall("sinpi", $gentype, ($gentype,), x)

@device_override Base.sqrt(x::$gentype) = @builtin_ccall("sqrt", $gentype, ($gentype,), x)

@device_override Base.tan(x::$gentype) = @builtin_ccall("tan", $gentype, ($gentype,), x)
@device_override Base.tanh(x::$gentype) = @builtin_ccall("tanh", $gentype, ($gentype,), x)
@device_function tanpi(x::$gentype) = @builtin_ccall("tanpi", $gentype, ($gentype,), x)
@device_override Base.tanpi(x::$gentype) = @builtin_ccall("tanpi", $gentype, ($gentype,), x)

@device_override SpecialFunctions.gamma(x::$gentype) = @builtin_ccall("tgamma", $gentype, ($gentype,), x)

Expand Down Expand Up @@ -151,11 +152,13 @@ end
# frexp(x::Float64{n}, Int32{n} *exp) = @builtin_ccall("frexp", Float64{n}, (Float64{n}, Int32{n} *), x, exp)
# frexp(x::Float64, Int32 *exp) = @builtin_ccall("frexp", Float64, (Float64, Int32 *), x, exp)

@device_function ilogb(x::Float16) = @builtin_ccall("ilogb", Int32, (Float16,), x)
# ilogb(x::Float32{n}) = @builtin_ccall("ilogb", Int32{n}, (Float32{n},), x)
@device_function ilogb(x::Float32) = @builtin_ccall("ilogb", Int32, (Float32,), x)
# ilogb(x::Float64{n}) = @builtin_ccall("ilogb", Int32{n}, (Float64{n},), x)
@device_function ilogb(x::Float64) = @builtin_ccall("ilogb", Int32, (Float64,), x)

@device_override Base.ldexp(x::Float16, k::Int32) = @builtin_ccall("ldexp", Float16, (Float16, Int32), x, k)
# ldexp(x::Float32{n}, k::Int32{n}) = @builtin_ccall("ldexp", Float32{n}, (Float32{n}, Int32{n}), x, k)
# ldexp(x::Float32{n}, k::Int32) = @builtin_ccall("ldexp", Float32{n}, (Float32{n}, Int32), x, k)
@device_override Base.ldexp(x::Float32, k::Int32) = @builtin_ccall("ldexp", Float32, (Float32, Int32), x, k)
Expand All @@ -168,11 +171,13 @@ end
# lgamma_r(x::Float64{n}, Int32{n} *signp) = @builtin_ccall("lgamma_r", Float64{n}, (Float64{n}, Int32{n} *), x, signp)
# Float64 lgamma_r(x::Float64, Int32 *signp) = @builtin_ccall("lgamma_r", Float64, (Float64, Int32 *), x, signp)

@device_function nan(nancode::UInt16) = @builtin_ccall("nan", Float16, (UInt16,), nancode)
# nan(nancode::uintn) = @builtin_ccall("nan", Float32{n}, (uintn,), nancode)
@device_function nan(nancode::UInt32) = @builtin_ccall("nan", Float32, (UInt32,), nancode)
# nan(nancode::UInt64{n}) = @builtin_ccall("nan", Float64{n}, (UInt64{n},), nancode)
@device_function nan(nancode::UInt64) = @builtin_ccall("nan", Float64, (UInt64,), nancode)

@device_override Base.:(^)(x::Float16, y::Int32) = @builtin_ccall("pown", Float16, (Float16, Int32), x, y)
# pown(x::Float32{n}, y::Int32{n}) = @builtin_ccall("pown", Float32{n}, (Float32{n}, Int32{n}), x, y)
@device_override Base.:(^)(x::Float32, y::Int32) = @builtin_ccall("pown", Float32, (Float32, Int32), x, y)
# pown(x::Float64{n}, y::Int32{n}) = @builtin_ccall("pown", Float64{n}, (Float64{n}, Int32{n}), x, y)
Expand All @@ -183,10 +188,11 @@ end
# remquo(x::Float64{n}, y::Float64{n}, Int32{n} *quo) = @builtin_ccall("remquo", Float64{n}, (Float64{n}, Float64{n}, Int32{n} *), x, y, quo)
# remquo(x::Float64, y::Float64, Int32 *quo) = @builtin_ccall("remquo", Float64, (Float64, Float64, Int32 *), x, y, quo)

@device_function rootn(x::Float16, y::Int32) = @builtin_ccall("rootn", Float16, (Float16, Int32), x, y)
# rootn(x::Float32{n}, y::Int32{n}) = @builtin_ccall("rootn", Float32{n}, (Float32{n}, Int32{n}), x, y)
@device_function rootn(x::Float32, y::Int32) = @builtin_ccall("rootn", Float32, (Float32, Int32), x, y)
# rootn(x::Float64{n}, y::Int32{n}) = @builtin_ccall("rootn", Float64{n}, (Float64{n}, Int32{n}), x, y)
# rootn(x::Float64, y::Int32) = @builtin_ccall("rootn", Float64{n}, (Float64, Int32), x, y)
@device_function rootn(x::Float64, y::Int32) = @builtin_ccall("rootn", Float64, (Float64, Int32), x, y)


# TODO: half and native
Expand Down
9 changes: 6 additions & 3 deletions lib/intrinsics/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ macro builtin_ccall(name, ret, argtypes, args...)
"c"
elseif T == UInt8
"h"
elseif T == Float16
"Dh"
elseif T == Float32
"f"
elseif T == Float64
Expand Down Expand Up @@ -61,17 +63,18 @@ macro builtin_ccall(name, ret, argtypes, args...)
error("Unknown type $T")
end
end
mangle(::Type{NTuple{N, VecElement{T}}}) where {N, T} = "Dv$(N)_" * mangle(T)

# C++-style mangling; very limited to just support these intrinsics
# TODO: generalize for use with other intrinsics? do we need to mangle those?
mangled = "_Z$(length(name))$name"
for t in argtypes
# with `@eval @builtin_ccall`, we get actual types in the ast, otherwise symbols
t = (isa(t, Symbol) || isa(t, Expr)) ? eval(t) : t
t = (isa(t, Symbol) || isa(t, Expr)) ? __module__.eval(t) : t
mangled *= mangle(t)
end

push!(known_intrinsics, mangled)
push!(__module__.known_intrinsics, mangled)
esc(quote
@typed_ccall($mangled, llvmcall, $ret, ($(argtypes...),), $(args...))
end)
Expand All @@ -85,7 +88,7 @@ Base.Experimental.@MethodTable(method_table)

macro device_override(ex)
esc(quote
Base.Experimental.@overlay(method_table, $ex)
Base.Experimental.@overlay($method_table, $ex)
end)
end

Expand Down
3 changes: 3 additions & 0 deletions src/compiler/compilation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ GPUCompiler.isintrinsic(job::OpenCLCompilerJob, fn::String) =
Tuple{CompilerJob{SPIRVCompilerTarget}, typeof(fn)},
job, fn) ||
in(fn, known_intrinsics) ||
let SPIRVIntrinsicsSIMDExt = Base.get_extension(SPIRVIntrinsics, :SPIRVIntrinsicsSIMDExt)
SPIRVIntrinsicsSIMDExt !== nothing && in(fn, SPIRVIntrinsicsSIMDExt.known_intrinsics)
end ||
contains(fn, "__spirv_")


Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SIMD = "fdea26ae-647d-5447-a871-4b548cad5224"
SPIRVIntrinsics = "71d1d633-e7e8-4a92-83a1-de8814b09ba8"
SPIRV_LLVM_Backend_jll = "4376b9bf-cff8-51b6-bb48-39421dff0d0c"
SPIRV_LLVM_Translator_jll = "4a5d46fc-d8cf-5151-a261-86b458210efb"
Expand Down
2 changes: 1 addition & 1 deletion test/atomics.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using SPIRVIntrinsics: @builtin_ccall, @typed_ccall, LLVMPtr
using SPIRVIntrinsics: @builtin_ccall, @typed_ccall, LLVMPtr, known_intrinsics

@testset "atomics" begin

Expand Down
Loading
Loading