@@ -3,6 +3,13 @@ module ForwardDiffGPUArraysCoreExt
3
3
using GPUArraysCore: AbstractGPUArray
4
4
using ForwardDiff: ForwardDiff, Dual, Partials, npartials, partials
5
5
6
+ struct PartialsFn{T,D<: Dual }
7
+ dual:: D
8
+ end
9
+ PartialsFn {T} (dual:: Dual ) where {T} = PartialsFn {T,typeof(dual)} (dual)
10
+
11
+ (f:: PartialsFn{T} )(i) where {T} = partials (T, f. dual, i)
12
+
6
13
function ForwardDiff. seed! (duals:: AbstractGPUArray{Dual{T,V,N}} , x,
7
14
seed:: Partials{N,V} ) where {T,V,N}
8
15
idxs = collect (ForwardDiff. structural_eachindex (duals, x))
41
48
# gradient
42
49
function ForwardDiff. extract_gradient! (:: Type{T} , result:: AbstractGPUArray ,
43
50
dual:: Dual ) where {T}
44
- # this closure is needed for gpu compilation
45
- partial_fn (dual, i) = partials (T, dual, i)
46
-
51
+ fn = PartialsFn {T} (dual)
47
52
idxs = collect (Iterators. take (ForwardDiff. structural_eachindex (result), npartials (dual)))
48
- result[idxs] .= partial_fn .( Ref (dual), 1 : length (idxs))
53
+ result[idxs] .= fn .( 1 : length (idxs))
49
54
return result
50
55
end
51
56
52
57
function ForwardDiff. extract_gradient_chunk! (:: Type{T} , result:: AbstractGPUArray , dual,
53
58
index, chunksize) where {T}
54
- # this closure is needed for gpu compilation
55
- partial_fn (dual, i) = partials (T, dual, i)
56
-
59
+ fn = PartialsFn {T} (dual)
57
60
offset = index - 1
58
61
idxs = collect (
59
62
Iterators. take (
60
63
Iterators. drop (ForwardDiff. structural_eachindex (result), offset),
61
64
chunksize,
62
65
)
63
66
)
64
- result[idxs] .= partial_fn .( Ref (dual), 1 : length (idxs))
67
+ result[idxs] .= fn .( 1 : length (idxs))
65
68
return result
66
69
end
67
70
0 commit comments