diff --git a/Project.toml b/Project.toml index e43900af..77913d55 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ForwardDiff" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "1.0.1" +version = "1.0.2" [deps] CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950" @@ -13,12 +13,13 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [weakdeps] +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [extensions] +ForwardDiffGPUArraysCoreExt = "GPUArraysCore" ForwardDiffStaticArraysExt = "StaticArrays" [compat] @@ -27,22 +28,25 @@ CommonSubexpressions = "0.3" DiffResults = "1.1" DiffRules = "1.4" DiffTests = "0.1" +GPUArraysCore = "0.1, 0.2" IrrationalConstants = "0.1, 0.2" LogExpFunctions = "0.3" NaNMath = "1" Preferences = "1" SpecialFunctions = "1, 2" StaticArrays = "1.5" -julia = "1.6" +julia = "1.10" [extras] Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Calculus", "DiffTests", "IrrationalConstants", "SparseArrays", "StaticArrays", "Test", "InteractiveUtils"] +test = ["Calculus", "DiffTests", "IrrationalConstants", "SparseArrays", "StaticArrays", "Test", "InteractiveUtils", "JLArrays"] diff --git a/ext/ForwardDiffGPUArraysCoreExt.jl b/ext/ForwardDiffGPUArraysCoreExt.jl new file mode 100644 index 00000000..bf63da00 --- /dev/null +++ b/ext/ForwardDiffGPUArraysCoreExt.jl @@ -0,0 +1,67 @@ +module ForwardDiffGPUArraysCoreExt + +using GPUArraysCore: AbstractGPUArray +using ForwardDiff: ForwardDiff, Dual, Partials, npartials, partials + +struct PartialsFn{T,D<:Dual} + dual::D +end +PartialsFn{T}(dual::Dual) where {T} = PartialsFn{T,typeof(dual)}(dual) + +(f::PartialsFn{T})(i) where {T} = partials(T, f.dual, i) + +_take(itr, N::Integer) = Iterators.take(itr, min(length(itr), N)) + +function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, + seed::Partials{N,V}) where {T,V,N} + idxs = collect(ForwardDiff.structural_eachindex(duals, x)) + duals[idxs] .= Dual{T,V,N}.(view(x, idxs), Ref(seed)) + return duals +end + +function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, + seeds::NTuple{N,Partials{N,V}}) where {T,V,N} + idxs = collect(_take(ForwardDiff.structural_eachindex(duals, x), N)) + duals[idxs] .= Dual{T,V,N}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs))) + return duals +end + +function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index, + seed::Partials{N,V}) where {T,V,N} + offset = index - 1 + idxs = collect(Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset)) + duals[idxs] .= Dual{T,V,N}.(view(x, idxs), Ref(seed)) + return duals +end + +function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index, + seeds::NTuple{N,Partials{N,V}}, chunksize) where {T,V,N} + offset = index - 1 + idxs = collect( + _take(Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset), chunksize) + ) + duals[idxs] .= Dual{T,V,N}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs))) + return duals +end + +# gradient +function ForwardDiff.extract_gradient!(::Type{T}, result::AbstractGPUArray, + dual::Dual) where {T} + fn = PartialsFn{T}(dual) + idxs = collect(_take(ForwardDiff.structural_eachindex(result), npartials(dual))) + result[idxs] .= fn.(1:length(idxs)) + return result +end + +function ForwardDiff.extract_gradient_chunk!(::Type{T}, result::AbstractGPUArray, dual, + index, chunksize) where {T} + fn = PartialsFn{T}(dual) + offset = index - 1 + idxs = collect( + _take(Iterators.drop(ForwardDiff.structural_eachindex(result), offset), chunksize) + ) + result[idxs] .= fn.(1:length(idxs)) + return result +end + +end diff --git a/src/ForwardDiff.jl b/src/ForwardDiff.jl index fdfcd560..b16b986b 100644 --- a/src/ForwardDiff.jl +++ b/src/ForwardDiff.jl @@ -22,10 +22,6 @@ include("gradient.jl") include("jacobian.jl") include("hessian.jl") -if !isdefined(Base, :get_extension) - include("../ext/ForwardDiffStaticArraysExt.jl") -end - export DiffResults end # module diff --git a/test/GradientTest.jl b/test/GradientTest.jl index 4f46c167..5c2c0938 100644 --- a/test/GradientTest.jl +++ b/test/GradientTest.jl @@ -9,6 +9,7 @@ using ForwardDiff using ForwardDiff: Dual, Tag using StaticArrays using DiffTests +using JLArrays include(joinpath(dirname(@__FILE__), "utils.jl")) @@ -255,4 +256,25 @@ end end end +@testset "GPUArraysCore" begin + fn(x) = sum(x .^ 2 ./ 2) + + x = [1.0, 2.0, 3.0] + x_jl = JLArray(x) + + grad = ForwardDiff.gradient(fn, x) + grad_jl = ForwardDiff.gradient(fn, x_jl) + + @test grad_jl isa JLArray + @test Array(grad_jl) ≈ grad + + cfg = ForwardDiff.GradientConfig( + fn, x_jl, ForwardDiff.Chunk{2}(), ForwardDiff.Tag(fn, eltype(x)) + ) + grad_jl = ForwardDiff.gradient(fn, x_jl, cfg) + + @test grad_jl isa JLArray + @test Array(grad_jl) ≈ grad +end + end # module diff --git a/test/JacobianTest.jl b/test/JacobianTest.jl index 1e52f7fa..865503b5 100644 --- a/test/JacobianTest.jl +++ b/test/JacobianTest.jl @@ -8,6 +8,7 @@ using ForwardDiff: Dual, Tag, JacobianConfig using StaticArrays using DiffTests using LinearAlgebra +using JLArrays include(joinpath(dirname(@__FILE__), "utils.jl")) @@ -279,4 +280,17 @@ end end end +@testset "GPUArraysCore" begin + f(x) = x .^ 2 ./ 2 + + x = [1.0, 2.0, 3.0] + x_jl = JLArray(x) + + jac = ForwardDiff.jacobian(f, x) + jac_jl = ForwardDiff.jacobian(f, x_jl) + + @test jac_jl isa JLArray + @test Array(jac_jl) ≈ jac +end + end # module