@@ -4,34 +4,37 @@ using GPUArraysCore: AbstractGPUArray
4
4
using ForwardDiff: ForwardDiff, Dual, Partials, npartials, partials
5
5
6
6
# Diff view of the `seed!` method for GPU arrays of duals that takes one `seed`.
# Lines starting with `-` are the previous version; lines starting with `+` replace them.
# The bare numbers are old/new line-number gutters, not code.
#
# Behaviour (both versions): collect the indices yielded by
# `ForwardDiff.structural_eachindex(duals, x)`, write `Dual{T,V,N}(x[i], seed)` into
# `duals` at each of those indices, and return `duals`.
# `Ref(seed)` keeps the single `seed` from being iterated by the broadcast.
#
# What the change does:
#   * `seed` loses its `zero(Partials{N,V})` default, so this method now needs it explicitly.
#   * `x[idxs]` (an indexing copy) becomes `view(x, idxs)`.
function ForwardDiff. seed! (duals:: AbstractGPUArray{Dual{T,V,N}} , x,
7
- seed:: Partials{N,V} = zero (Partials{N,V}) ) where {T,V,N}
7
+ seed:: Partials{N,V} ) where {T,V,N}
8
8
idxs = collect (ForwardDiff. structural_eachindex (duals, x))
9
- duals[idxs] .= Dual {T,V,N} .(x[ idxs] , Ref (seed))
9
+ duals[idxs] .= Dual {T,V,N} .(view (x, idxs) , Ref (seed))
10
10
return duals
11
11
end
12
12
13
13
# Diff view of the `seed!` method for GPU arrays of duals that takes a tuple of `N` seeds.
# Lines starting with `-` are the previous version; lines starting with `+` replace them.
# The bare numbers are old/new line-number gutters, not code.
#
# Behaviour: the k-th collected structural index `i` receives `Dual{T,V,N}(x[i], seeds[k])`;
# `duals` is returned.
#
# What the change does:
#   * Previously all structural indices were collected and then sliced with `[1:N]`, and
#     the seeds were sliced with `seeds[1:N]`.
#   * Now `Iterators.take(..., N)` limits the indices before collecting, and the seeds are
#     looked up with `getindex.(Ref(seeds), 1:length(idxs))`, i.e. exactly as many seeds as
#     indices were actually collected.
#   * `x[idxs]` (an indexing copy) becomes `view(x, idxs)`.
# NOTE(review): presumably this lets the method cope with fewer than `N` structural
# indices instead of slicing out of range — confirm against the callers.
function ForwardDiff. seed! (duals:: AbstractGPUArray{Dual{T,V,N}} , x,
14
14
seeds:: NTuple{N,Partials{N,V}} ) where {T,V,N}
15
- idxs = collect (ForwardDiff. structural_eachindex (duals, x))[ 1 : N]
16
- duals[idxs] .= Dual {T,V,N} .(x[ idxs], seeds[ 1 : N] )
15
+ idxs = collect (Iterators . take ( ForwardDiff. structural_eachindex (duals, x), N))
16
+ duals[idxs] .= Dual {T,V,N} .(view (x, idxs), getindex .( Ref ( seeds), 1 : length (idxs)) )
17
17
return duals
18
18
end
19
19
20
20
# Diff view of the `seed!` method for GPU arrays of duals that takes a start `index` and
# one `seed`.
# Lines starting with `-` are the previous version; lines starting with `+` replace them.
# The bare numbers are old/new line-number gutters, not code.
#
# Behaviour (both versions): skip the first `index - 1` entries of
# `ForwardDiff.structural_eachindex(duals, x)`, write `Dual{T,V,N}(x[i], seed)` into
# `duals` at every remaining index, and return `duals`.
# `Ref(seed)` keeps the single `seed` from being iterated by the broadcast.
#
# What the change does:
#   * `seed` loses its `zero(Partials{N,V})` default, so this method now needs it explicitly.
#   * `x[idxs]` (an indexing copy) becomes `view(x, idxs)`.
function ForwardDiff. seed! (duals:: AbstractGPUArray{Dual{T,V,N}} , x, index,
21
- seed:: Partials{N,V} = zero (Partials{N,V}) ) where {T,V,N}
21
+ seed:: Partials{N,V} ) where {T,V,N}
22
22
# number of leading structural indices to drop so that seeding starts at `index`
offset = index - 1
23
23
idxs = collect (Iterators. drop (ForwardDiff. structural_eachindex (duals, x), offset))
24
- duals[idxs] .= Dual {T,V,N} .(x[ idxs] , Ref (seed))
24
+ duals[idxs] .= Dual {T,V,N} .(view (x, idxs) , Ref (seed))
25
25
return duals
26
26
end
27
27
28
28
# Diff view of the `seed!` method for GPU arrays of duals that takes a start `index`, a
# tuple of `N` seeds and a `chunksize`.
# Lines starting with `-` are the previous version; lines starting with `+` replace them.
# The bare numbers are old/new line-number gutters, not code.
#
# Behaviour: skip the first `index - 1` entries of
# `ForwardDiff.structural_eachindex(duals, x)`, then seed up to `chunksize` of the
# following indices; the k-th of them, `i`, receives `Dual{T,V,N}(x[i], seeds[k])`.
# `duals` is returned.
#
# What the change does:
#   * `chunksize` loses its `= N` default, so this method now needs it explicitly.
#   * Previously the dropped iterator was collected in full and sliced with
#     `[1:chunksize]`, and the seeds were sliced with `seeds[1:chunksize]`.
#   * Now `Iterators.take(..., chunksize)` limits the indices before collecting, and the
#     seeds are looked up with `getindex.(Ref(seeds), 1:length(idxs))`, i.e. exactly as
#     many seeds as indices were actually collected.
#   * `x[idxs]` (an indexing copy) becomes `view(x, idxs)`.
# NOTE(review): presumably this lets a final chunk with fewer than `chunksize` indices
# be seeded without slicing out of range — confirm against the callers.
function ForwardDiff. seed! (duals:: AbstractGPUArray{Dual{T,V,N}} , x, index,
29
- seeds:: NTuple{N,Partials{N,V}} , chunksize = N ) where {T,V,N}
29
+ seeds:: NTuple{N,Partials{N,V}} , chunksize) where {T,V,N}
30
30
# number of leading structural indices to drop so that seeding starts at `index`
offset = index - 1
31
31
idxs = collect (
32
- Iterators. drop (ForwardDiff. structural_eachindex (duals, x), offset)
33
- )[1 : chunksize]
34
- duals[idxs] .= Dual {T,V,N} .(x[idxs], seeds[1 : chunksize])
32
+ Iterators. take (
33
+ Iterators. drop (ForwardDiff. structural_eachindex (duals, x), offset),
34
+ chunksize,
35
+ ),
36
+ )
37
+ duals[idxs] .= Dual {T,V,N} .(view (x, idxs), getindex .(Ref (seeds), 1 : length (idxs)))
35
38
return duals
36
39
end
37
40
@@ -41,8 +44,8 @@ function ForwardDiff.extract_gradient!(::Type{T}, result::AbstractGPUArray,
41
44
# this closure is needed for gpu compilation
42
45
partial_fn (dual, i) = partials (T, dual, i)
43
46
44
- idxs = ForwardDiff. structural_eachindex (result)
45
- result[idxs] .= partial_fn .(Ref (dual), 1 : npartials (dual ))
47
+ idxs = collect (Iterators . take ( ForwardDiff. structural_eachindex (result), npartials (dual)) )
48
+ result[idxs] .= partial_fn .(Ref (dual), 1 : length (idxs ))
46
49
return result
47
50
end
48
51
@@ -53,9 +56,12 @@ function ForwardDiff.extract_gradient_chunk!(::Type{T}, result::AbstractGPUArray
53
56
54
57
offset = index - 1
55
58
idxs = collect (
56
- Iterators. drop (ForwardDiff. structural_eachindex (result), offset)
57
- )[1 : chunksize]
58
- result[idxs] .= partial_fn .(Ref (dual), 1 : chunksize)
59
+ Iterators. take (
60
+ Iterators. drop (ForwardDiff. structural_eachindex (result), offset),
61
+ chunksize,
62
+ )
63
+ )
64
+ result[idxs] .= partial_fn .(Ref (dual), 1 : length (idxs))
59
65
return result
60
66
end
61
67
0 commit comments