From e2753ccaadf0d684a1d0ba0e7c9dd3e66c6e221a Mon Sep 17 00:00:00 2001
From: Anton Oresten
Date: Sun, 17 Aug 2025 18:38:27 +0200
Subject: [PATCH 1/3] Add batched_mul support for oneAPI

---
 Project.toml                   |  2 ++
 ext/NNliboneAPIExt.jl          | 28 ++++++++++++++++++++++++++++
 test/ext_oneapi/batched_mul.jl | 34 ++++++++++++++++++++++++++++++++++
 test/ext_oneapi/runtests.jl    |  5 +++++
 test/runtests.jl               | 20 +++++++++++++++++++-
 5 files changed, 88 insertions(+), 1 deletion(-)
 create mode 100644 ext/NNliboneAPIExt.jl
 create mode 100644 test/ext_oneapi/batched_mul.jl
 create mode 100644 test/ext_oneapi/runtests.jl

diff --git a/Project.toml b/Project.toml
index cc0d8fca..13fe1ede 100644
--- a/Project.toml
+++ b/Project.toml
@@ -21,6 +21,7 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
+oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
 
 [extensions]
 NNlibAMDGPUExt = "AMDGPU"
@@ -30,6 +31,7 @@ NNlibEnzymeCoreExt = "EnzymeCore"
 NNlibFFTWExt = "FFTW"
 NNlibForwardDiffExt = "ForwardDiff"
 NNlibSpecialFunctionsExt = "SpecialFunctions"
+NNliboneAPIExt = "oneAPI"
 
 [compat]
 AMDGPU = "1, 2"
diff --git a/ext/NNliboneAPIExt.jl b/ext/NNliboneAPIExt.jl
new file mode 100644
index 00000000..643f22b8
--- /dev/null
+++ b/ext/NNliboneAPIExt.jl
@@ -0,0 +1,28 @@
+module NNliboneAPIExt
+
+using NNlib
+using oneAPI
+
+function NNlib._batched_gemm!(::Type{<:oneArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C)
+    oneAPI.oneMKL.gemm_strided_batched!(transA, transB, α, A, B, β, C)
+end
+
+using NNlib: BatchedAdjoint, BatchedTranspose, BatchedAdjOrTrans
+using Adapt
+using Adapt: WrappedArray
+
+const oneAPIBatchedAdjoint{T} = BatchedAdjoint{T, <: oneArray{T}}
+const oneAPIBatchedTranspose{T} = BatchedTranspose{T, <: oneArray{T}}
+const oneAPIBatchedAdjOrTrans{T} = Union{oneAPIBatchedAdjoint{T}, oneAPIBatchedTranspose{T}}
+const WrappedoneAPIBatchedAdjOrTrans{T, N} = WrappedArray{T, N, oneAPIBatchedAdjOrTrans{T}, oneAPIBatchedAdjOrTrans{T}}
+
+Base.print_array(io::IO, b::Union{oneAPIBatchedAdjOrTrans, WrappedoneAPIBatchedAdjOrTrans}) = Base.print_array(io, adapt(Array, b))
+Base._show_nonempty(io::IO, b::Union{oneAPIBatchedAdjOrTrans, WrappedoneAPIBatchedAdjOrTrans}, prefix::String) = Base._show_nonempty(io, adapt(Array, b), prefix)
+Base.show_vector(io::IO, b::Union{oneAPIBatchedAdjOrTrans, WrappedoneAPIBatchedAdjOrTrans}, opn, cls) = Base.show_vector(io, adapt(Array, b), opn, cls)
+
+Base.convert(::Type{T}, b::Union{oneAPIBatchedAdjOrTrans, WrappedoneAPIBatchedAdjOrTrans}) where {T<:Array} = Base.convert(T, adapt(Array, b))
+Base.Array{T, N}(b::Union{oneAPIBatchedAdjOrTrans, WrappedoneAPIBatchedAdjOrTrans}) where {T, N} = Array{T, N}(adapt(Array, b))
+Base.collect(b::Union{oneAPIBatchedAdjOrTrans, WrappedoneAPIBatchedAdjOrTrans}) = collect(adapt(Array, b))
+
+
+end # module NNliboneAPIExt
\ No newline at end of file
diff --git a/test/ext_oneapi/batched_mul.jl b/test/ext_oneapi/batched_mul.jl
new file mode 100644
index 00000000..c0c01918
--- /dev/null
+++ b/test/ext_oneapi/batched_mul.jl
@@ -0,0 +1,34 @@
+@testset "batched_mul" begin
+    A = rand(Float32, 3, 3, 2)
+    B = rand(Float32, 3, 3, 2)
+    dA, dB = oneArray.((A, B))
+
+    C = batched_mul(A, B)
+    @test oneArray(C) ≈ batched_mul(dA, dB)
+
+    Ct = batched_mul(batched_transpose(A), B)
+    @test oneArray(Ct) ≈ batched_mul(batched_transpose(dA), dB)
+
+    Ca = batched_mul(A, batched_adjoint(B))
+    @test oneArray(Ca) ≈ batched_mul(dA, batched_adjoint(dB))
+
+    # 5-arg batched_mul!
+    C .= pi
+    batched_mul!(C, A, B, 2f0, 3f0)
+    Cpi = oneArray(similar(C)) .= pi
+    @test oneArray(C) ≈ batched_mul!(Cpi, dA, dB, 2f0, 3f0)
+
+    # PermutedDimsArray
+    @test oneArray(Ct) ≈ batched_mul(PermutedDimsArray(dA, (2, 1, 3)), dB)
+
+    # FIXME same but with (1, 3, 2) errors
+    D = permutedims(B, (2, 1, 3))
+    Cp = batched_mul(batched_adjoint(A), B)
+    @test oneArray(Cp) ≈ batched_mul(
+        batched_adjoint(dA), PermutedDimsArray(oneArray(D), (2, 1, 3)))
+
+    # Methods which reshape
+    M = randn(Float32, 3, 3)
+    Cm = batched_mul(A, M)
+    @test oneArray(Cm) ≈ batched_mul(dA, oneArray(M))
+end
diff --git a/test/ext_oneapi/runtests.jl b/test/ext_oneapi/runtests.jl
new file mode 100644
index 00000000..0459cc46
--- /dev/null
+++ b/test/ext_oneapi/runtests.jl
@@ -0,0 +1,5 @@
+oneAPI.allowscalar(false)
+
+@testset "Batched multiplication" begin
+    include("batched_mul.jl")
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index b8080b6b..c2daedfc 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -24,7 +24,8 @@ DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursiv
 
 # ENV["NNLIB_TEST_CUDA"] = "true" # uncomment to run CUDA tests
 # ENV["NNLIB_TEST_AMDGPU"] = "true" # uncomment to run AMDGPU tests
-# ENV["NNLIB_TEST_CPU"] = "false" # uncomment to skip CPU tests
+ENV["NNLIB_TEST_ONEAPI"] = "true" # uncomment to run oneAPI tests
+ENV["NNLIB_TEST_CPU"] = "false" # uncomment to skip CPU tests
 
 const rng = StableRNG(123)
 include("test_utils.jl")
@@ -184,4 +185,21 @@ end
     else
         @info "Skipping AMDGPU tests, set NNLIB_TEST_AMDGPU=true to run them."
     end
+
+    if get(ENV, "NNLIB_TEST_ONEAPI", "false") == "true"
+        Pkg.add("oneAPI")
+
+        using oneAPI
+        if oneAPI.functional()
+            @testset "oneAPI" begin
+                # nnlib_testsuite(oneAPIBackend)
+
+                include("ext_oneapi/runtests.jl")
+            end
+        else
+            @info "oneAPI.jl package is not functional. Skipping oneAPI tests."
+        end
+    else
+        @info "Skipping oneAPI tests, set NNLIB_TEST_ONEAPI=true to run them."
+    end
 end

From ba7ce5ab56507391d0c0229de3c362c70b10c263 Mon Sep 17 00:00:00 2001
From: Anton Oresten
Date: Sun, 17 Aug 2025 19:05:02 +0200
Subject: [PATCH 2/3] Add documentation note for oneAPI

---
 docs/src/index.md | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/docs/src/index.md b/docs/src/index.md
index 78d09df4..934b49a2 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -8,11 +8,15 @@ GPU support is provided as package extensions. In order to load the extensions,
 ```julia
 using NNlib, CUDA, cuDNN
 ```
-for CUDA support, or
+for CUDA support,
 ```julia
 using NNlib, AMDGPU
 ```
-for AMDGPU support.
+for AMDGPU support, or
+```julia
+using NNlib, oneAPI
+```
+for partial oneAPI support (particularly batched multiplication primitives).
 
 ## Threading
 

From 3a818b06f55bd5ecb5f59b3ca8b53036b6a8f49c Mon Sep 17 00:00:00 2001
From: Anton Oresten
Date: Sun, 17 Aug 2025 19:06:31 +0200
Subject: [PATCH 3/3] Update runtests.jl

---
 test/runtests.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/runtests.jl b/test/runtests.jl
index c2daedfc..1d68d304 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -24,8 +24,8 @@ DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursiv
 
 # ENV["NNLIB_TEST_CUDA"] = "true" # uncomment to run CUDA tests
 # ENV["NNLIB_TEST_AMDGPU"] = "true" # uncomment to run AMDGPU tests
-ENV["NNLIB_TEST_ONEAPI"] = "true" # uncomment to run oneAPI tests
-ENV["NNLIB_TEST_CPU"] = "false" # uncomment to skip CPU tests
+# ENV["NNLIB_TEST_ONEAPI"] = "true" # uncomment to run oneAPI tests
+# ENV["NNLIB_TEST_CPU"] = "false" # uncomment to skip CPU tests
 
 const rng = StableRNG(123)
 include("test_utils.jl")
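
For reference, the docs note in PATCH 2/3 mentions partial oneAPI support; a minimal usage sketch of the extension wired up in PATCH 1/3 is shown below. It assumes oneAPI.jl is installed on a machine with a functional Intel GPU; the array names and sizes are illustrative and not taken from the patch.

```julia
# Loading both packages triggers the NNliboneAPIExt package extension.
using NNlib, oneAPI

A = rand(Float32, 4, 3, 8)          # batch of eight 4×3 matrices
B = rand(Float32, 3, 5, 8)          # batch of eight 3×5 matrices
dA, dB = oneArray(A), oneArray(B)   # move the batches to the GPU

dC = batched_mul(dA, dB)                # dispatches to oneMKL.gemm_strided_batched!
@assert Array(dC) ≈ batched_mul(A, B)   # agrees with the CPU result

# 5-arg form, as exercised in test/ext_oneapi/batched_mul.jl:
# overwrites dC with 2 * (dA ⊠ dB) + 3 * dC.
batched_mul!(dC, dA, dB, 2f0, 3f0)
```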
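A hypothetical way to run the new test group locally, using the NNLIB_TEST_ONEAPI switch introduced in test/runtests.jl above (the exact workflow is an assumption, not part of the patch):

```julia
using Pkg

# Pkg.test spawns a test process that inherits this environment,
# where test/runtests.jl reads the switch via get(ENV, "NNLIB_TEST_ONEAPI", "false").
ENV["NNLIB_TEST_ONEAPI"] = "true"   # opt in to the oneAPI tests
ENV["NNLIB_TEST_CPU"] = "false"     # optionally skip the CPU suite
Pkg.test("NNlib")
```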