2 changes: 2 additions & 0 deletions Project.toml
@@ -21,6 +21,7 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"

[extensions]
NNlibAMDGPUExt = "AMDGPU"
@@ -30,6 +31,7 @@ NNlibEnzymeCoreExt = "EnzymeCore"
NNlibFFTWExt = "FFTW"
NNlibForwardDiffExt = "ForwardDiff"
NNlibSpecialFunctionsExt = "SpecialFunctions"
NNliboneAPIExt = "oneAPI"

[compat]
AMDGPU = "1, 2"
8 changes: 6 additions & 2 deletions docs/src/index.md
@@ -8,11 +8,15 @@ GPU support is provided as package extensions. In order to load the extensions,
```julia
using NNlib, CUDA, cuDNN
```
for CUDA support, or
for CUDA support,
```julia
using NNlib, AMDGPU
```
for AMDGPU support.
for AMDGPU support, or
```julia
using NNlib, oneAPI
```
for partial oneAPI support (in particular, the batched matrix multiplication primitives).

## Threading

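A minimal usage sketch of the oneAPI path described above (assumes oneAPI.jl is installed and a functional Intel GPU is available; the array sizes are purely illustrative):

```julia
using NNlib, oneAPI

A = oneArray(rand(Float32, 4, 5, 8))  # batch of eight 4×5 matrices
B = oneArray(rand(Float32, 5, 6, 8))  # batch of eight 5×6 matrices

C = batched_mul(A, B)                 # batched product, computed on the device
size(C)                               # (4, 6, 8)
```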
28 changes: 28 additions & 0 deletions ext/NNliboneAPIExt.jl
@@ -0,0 +1,28 @@
module NNliboneAPIExt

using NNlib
using oneAPI

function NNlib._batched_gemm!(::Type{<:oneArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C)
    oneAPI.oneMKL.gemm_strided_batched!(transA, transB, α, A, B, β, C)
end

using NNlib: BatchedAdjoint, BatchedTranspose, BatchedAdjOrTrans
using Adapt
using Adapt: WrappedArray

const oneAPIBatchedAdjoint{T} = BatchedAdjoint{T, <: oneArray{T}}
const oneAPIBatchedTranspose{T} = BatchedTranspose{T, <: oneArray{T}}
const oneAPIBatchedAdjOrTrans{T} = Union{oneAPIBatchedAdjoint{T}, oneAPIBatchedTranspose{T}}
const WrappedoneAPIBatchedAdjOrTrans{T, N} = WrappedArray{T, N, oneAPIBatchedAdjOrTrans{T}, oneAPIBatchedAdjOrTrans{T}}

Base.print_array(io::IO, b::Union{oneAPIBatchedAdjOrTrans, WrappedoneAPIBatchedAdjOrTrans}) = Base.print_array(io, adapt(Array, b))
Base._show_nonempty(io::IO, b::Union{oneAPIBatchedAdjOrTrans, WrappedoneAPIBatchedAdjOrTrans}, prefix::String) = Base._show_nonempty(io, adapt(Array, b), prefix)
Base.show_vector(io::IO, b::Union{oneAPIBatchedAdjOrTrans, WrappedoneAPIBatchedAdjOrTrans}, opn, cls) = Base.show_vector(io, adapt(Array, b), opn, cls)

Base.convert(::Type{T}, b::Union{oneAPIBatchedAdjOrTrans, WrappedoneAPIBatchedAdjOrTrans}) where {T<:Array} = Base.convert(T, adapt(Array, b))
Base.Array{T, N}(b::Union{oneAPIBatchedAdjOrTrans, WrappedoneAPIBatchedAdjOrTrans}) where {T, N} = Array{T, N}(adapt(Array, b))
Base.collect(b::Union{oneAPIBatchedAdjOrTrans, WrappedoneAPIBatchedAdjOrTrans}) = collect(adapt(Array, b))


end # module NNliboneAPIExt
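For orientation: `NNlib._batched_gemm!` is the internal hook that `batched_mul!` dispatches through once it has unwrapped any `batched_transpose`/`batched_adjoint` wrappers, so the single method above is what routes batched products on `oneArray`s to oneMKL's strided batched GEMM. A hedged sketch of the semantics (the slice-by-slice loop below is only a CPU reference check, not what oneMKL executes):

```julia
using NNlib, oneAPI

A = oneArray(rand(Float32, 2, 3, 4))
B = oneArray(rand(Float32, 3, 5, 4))
C = batched_mul(A, B)    # ends up in NNlib._batched_gemm!(::Type{<:oneArray}, ...)

# reference semantics: one matrix product per batch slice
Ah, Bh, Ch = Array(A), Array(B), Array(C)
all(Ch[:, :, k] ≈ Ah[:, :, k] * Bh[:, :, k] for k in 1:size(Ah, 3))
```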
34 changes: 34 additions & 0 deletions test/ext_oneapi/batched_mul.jl
@@ -0,0 +1,34 @@
@testset "batched_mul" begin
A = rand(Float32, 3, 3, 2)
B = rand(Float32, 3, 3, 2)
dA, dB = oneArray.((A, B))

C = batched_mul(A, B)
@test oneArray(C) ≈ batched_mul(dA, dB)

Ct = batched_mul(batched_transpose(A), B)
@test oneArray(Ct) ≈ batched_mul(batched_transpose(dA), dB)

Ca = batched_mul(A, batched_adjoint(B))
@test oneArray(Ca) ≈ batched_mul(dA, batched_adjoint(dB))

# 5-arg batched_mul!
C .= pi
batched_mul!(C, A, B, 2f0, 3f0)
Cpi = oneArray(similar(C)) .= pi
@test oneArray(C) ≈ batched_mul!(Cpi, dA, dB, 2f0, 3f0)

# PermutedDimsArray
@test oneArray(Ct) ≈ batched_mul(PermutedDimsArray(dA, (2, 1, 3)), dB)

# FIXME same but with (1, 3, 2) errors
D = permutedims(B, (2, 1, 3))
Cp = batched_mul(batched_adjoint(A), B)
@test oneArray(Cp) ≈ batched_mul(
batched_adjoint(dA), PermutedDimsArray(oneArray(D), (2, 1, 3)))

# Methods which reshape
M = randn(Float32, 3, 3)
Cm = batched_mul(A, M)
@test oneArray(Cm) ≈ batched_mul(dA, oneArray(M))
end
5 changes: 5 additions & 0 deletions test/ext_oneapi/runtests.jl
@@ -0,0 +1,5 @@
oneAPI.allowscalar(false)

@testset "Batched multiplication" begin
include("batched_mul.jl")
end
18 changes: 18 additions & 0 deletions test/runtests.jl
@@ -24,6 +24,7 @@ DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursiv

# ENV["NNLIB_TEST_CUDA"] = "true" # uncomment to run CUDA tests
# ENV["NNLIB_TEST_AMDGPU"] = "true" # uncomment to run AMDGPU tests
# ENV["NNLIB_TEST_ONEAPI"] = "true" # uncomment to run oneAPI tests
# ENV["NNLIB_TEST_CPU"] = "false" # uncomment to skip CPU tests

const rng = StableRNG(123)
@@ -184,4 +185,21 @@ end
else
    @info "Skipping AMDGPU tests, set NNLIB_TEST_AMDGPU=true to run them."
end

if get(ENV, "NNLIB_TEST_ONEAPI", "false") == "true"
    Pkg.add("oneAPI")

    using oneAPI
    if oneAPI.functional()
        @testset "oneAPI" begin
            # nnlib_testsuite(oneAPIBackend)

            include("ext_oneapi/runtests.jl")
        end
    else
        @info "oneAPI.jl package is not functional. Skipping oneAPI tests."
    end
else
    @info "Skipping oneAPI tests, set NNLIB_TEST_ONEAPI=true to run them."
end
end
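One way to run the new test group locally, assuming a machine on which oneAPI.jl is functional (environment variables set in the parent session are inherited by the test process):

```julia
using Pkg
ENV["NNLIB_TEST_ONEAPI"] = "true"   # opt in to the oneAPI tests added above
ENV["NNLIB_TEST_CPU"] = "false"     # optional: skip the CPU test groups
Pkg.test("NNlib")
```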