Commit 22c7096

fix: support gradient + more test coverage

1 parent 4becd41 commit 22c7096

File tree: 3 files changed (+56, −7 lines)

- Project.toml
- ext/ForwardDiffGPUArraysCoreExt.jl
- test/GradientTest.jl

Project.toml

Lines changed: 1 addition & 0 deletions

```diff
@@ -6,6 +6,7 @@ version = "1.0.2"
 CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950"
 DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
 DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
+GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
```
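This hunk only shows the new GPUArraysCore entry among the existing dependency UUIDs. A package extension like ext/ForwardDiffGPUArraysCoreExt.jl is declared in Project.toml via an [extensions] stanza that maps the extension module to its trigger package; a minimal sketch of that declaration (assumed for illustration, the stanza sits outside the hunk shown here):

```toml
# Hypothetical stanza, not visible in this diff: loading the trigger
# package (GPUArraysCore) causes Julia to load the extension module.
[extensions]
ForwardDiffGPUArraysCoreExt = "GPUArraysCore"
```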

ext/ForwardDiffGPUArraysCoreExt.jl

Lines changed: 33 additions & 7 deletions

```diff
@@ -1,36 +1,62 @@
 module ForwardDiffGPUArraysCoreExt
 
 using GPUArraysCore: AbstractGPUArray
-using ForwardDiff: ForwardDiff, Dual, Partials
+using ForwardDiff: ForwardDiff, Dual, Partials, npartials, partials
 
 function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x,
                            seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
-    idxs = ForwardDiff.structural_eachindex(duals, x)
-    duals[idxs] .= Dual{T,V,N}.(x[idxs], seed)
+    idxs = collect(ForwardDiff.structural_eachindex(duals, x))
+    duals[idxs] .= Dual{T,V,N}.(x[idxs], Ref(seed))
     return duals
 end
 
 function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x,
                            seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
-    idxs = ForwardDiff.structural_eachindex(duals, x)
+    idxs = collect(ForwardDiff.structural_eachindex(duals, x))[1:N]
     duals[idxs] .= Dual{T,V,N}.(x[idxs], seeds[1:N])
     return duals
 end
 
 function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index,
                            seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
     offset = index - 1
-    idxs = Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset)
-    duals[idxs] .= Dual{T,V,N}.(x[idxs], seed)
+    idxs = collect(Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset))
+    duals[idxs] .= Dual{T,V,N}.(x[idxs], Ref(seed))
     return duals
 end
 
 function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index,
                            seeds::NTuple{N,Partials{N,V}}, chunksize = N) where {T,V,N}
     offset = index - 1
-    idxs = Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset)
+    idxs = collect(
+        Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset)
+    )[1:chunksize]
     duals[idxs] .= Dual{T,V,N}.(x[idxs], seeds[1:chunksize])
     return duals
 end
 
+# gradient
+function ForwardDiff.extract_gradient!(::Type{T}, result::AbstractGPUArray,
+                                       dual::Dual) where {T}
+    # this closure is needed for gpu compilation
+    partial_fn(dual, i) = partials(T, dual, i)
+
+    idxs = ForwardDiff.structural_eachindex(result)
+    result[idxs] .= partial_fn.(Ref(dual), 1:npartials(dual))
+    return result
+end
+
+function ForwardDiff.extract_gradient_chunk!(::Type{T}, result::AbstractGPUArray, dual,
+                                             index, chunksize) where {T}
+    # this closure is needed for gpu compilation
+    partial_fn(dual, i) = partials(T, dual, i)
+
+    offset = index - 1
+    idxs = collect(
+        Iterators.drop(ForwardDiff.structural_eachindex(result), offset)
+    )[1:chunksize]
+    result[idxs] .= partial_fn.(Ref(dual), 1:chunksize)
+    return result
+end
+
 end
```
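The seeding changes follow a GPU-safe pattern: collect materializes the index iterator so the array is indexed in one vectorized getindex/setindex! instead of element-by-element scalar indexing (disallowed on GPU arrays), and Ref wraps the shared Partials seed so broadcasting treats it as a scalar rather than iterating it. A minimal CPU sketch of what the first seed! method's broadcast computes (eachindex stands in here for ForwardDiff.structural_eachindex, which is internal):

```julia
using ForwardDiff: Dual, Partials

V, N = Float64, 2
x = [1.0, 2.0, 3.0]
duals = Vector{Dual{Nothing,V,N}}(undef, length(x))

# One shared seed for every element; Ref stops broadcast from iterating it.
seed = Partials{N,V}((1.0, 0.0))

# collect turns the index iterator into a plain vector, so a GPU array
# could be read and written in bulk rather than one scalar at a time.
idxs = collect(eachindex(duals, x))
duals[idxs] .= Dual{Nothing,V,N}.(x[idxs], Ref(seed))

@assert duals[2] === Dual{Nothing,V,N}(2.0, seed)
```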

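The new extract_gradient! methods keep the same broadcast-only discipline. Per the comment in the diff, the local partial_fn closure captures the tag type T so the broadcast kernel never receives a Type argument, which otherwise blocks GPU compilation. A CPU sketch of what the unchunked method computes:

```julia
using ForwardDiff: Dual, partials, npartials

T = Nothing                    # tag type; Nothing is the untagged case
dual = Dual{T}(3.0, 1.0, 2.0)  # value 3.0 with partials (1.0, 2.0)

# The closure bakes T in, instead of broadcasting over the Type itself.
partial_fn(d, i) = partials(T, d, i)

# Pull the i-th partial out of the single Dual for each output index.
result = zeros(npartials(dual))
result .= partial_fn.(Ref(dual), 1:npartials(dual))

@assert result == [1.0, 2.0]
```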
test/GradientTest.jl

Lines changed: 22 additions & 0 deletions

```diff
@@ -9,6 +9,7 @@ using ForwardDiff
 using ForwardDiff: Dual, Tag
 using StaticArrays
 using DiffTests
+using JLArrays
 
 include(joinpath(dirname(@__FILE__), "utils.jl"))
 
@@ -255,4 +256,25 @@ end
     end
 end
 
+@testset "GPUArraysCore" begin
+    fn(x) = sum(x .^ 2 ./ 2)
+
+    x = [1.0, 2.0, 3.0]
+    x_jl = JLArray(x)
+
+    grad = ForwardDiff.gradient(fn, x)
+    grad_jl = ForwardDiff.gradient(fn, x_jl)
+
+    @test grad_jl isa JLArray
+    @test Array(grad_jl) ≈ grad
+
+    cfg = ForwardDiff.GradientConfig(
+        fn, x_jl, ForwardDiff.Chunk{2}(), ForwardDiff.Tag(fn, eltype(x))
+    )
+    grad_jl = ForwardDiff.gradient(fn, x_jl, cfg)
+
+    @test grad_jl isa JLArray
+    @test Array(grad_jl) ≈ grad
+end
+
 end # module
```
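The tests use JLArray from JLArrays, a CPU-backed reference implementation of AbstractGPUArray that enforces GPU semantics (no scalar indexing), so the new code paths are exercised without GPU hardware. A hypothetical spot check against a real backend would look the same; this assumes CUDA.jl, which is not part of this commit:

```julia
using CUDA, ForwardDiff

fn(x) = sum(x .^ 2 ./ 2)   # the gradient of fn is x itself

x_cu = CuArray([1.0, 2.0, 3.0])
grad_cu = ForwardDiff.gradient(fn, x_cu)

@assert grad_cu isa CuArray
@assert Array(grad_cu) ≈ [1.0, 2.0, 3.0]
```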
