@@ -10,6 +10,8 @@ PartialsFn{T}(dual::Dual) where {T} = PartialsFn{T,typeof(dual)}(dual)
10
10
11
11
(f:: PartialsFn{T} )(i) where {T} = partials (T, f. dual, i)
12
12
13
+ _take (itr, N:: Integer ) = Iterators. take (itr, min (length (itr), N))
14
+
13
15
function ForwardDiff. seed! (duals:: AbstractGPUArray{Dual{T,V,N}} , x,
14
16
seed:: Partials{N,V} ) where {T,V,N}
15
17
idxs = collect (ForwardDiff. structural_eachindex (duals, x))
19
21
20
22
function ForwardDiff. seed! (duals:: AbstractGPUArray{Dual{T,V,N}} , x,
21
23
seeds:: NTuple{N,Partials{N,V}} ) where {T,V,N}
22
- idxs = collect (Iterators . take (ForwardDiff. structural_eachindex (duals, x), N))
24
+ idxs = collect (_take (ForwardDiff. structural_eachindex (duals, x), N))
23
25
duals[idxs] .= Dual {T,V,N} .(view (x, idxs), getindex .(Ref (seeds), 1 : length (idxs)))
24
26
return duals
25
27
end
@@ -36,10 +38,7 @@ function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index,
36
38
seeds:: NTuple{N,Partials{N,V}} , chunksize) where {T,V,N}
37
39
offset = index - 1
38
40
idxs = collect (
39
- Iterators. take (
40
- Iterators. drop (ForwardDiff. structural_eachindex (duals, x), offset),
41
- chunksize,
42
- ),
41
+ _take (Iterators. drop (ForwardDiff. structural_eachindex (duals, x), offset), chunksize)
43
42
)
44
43
duals[idxs] .= Dual {T,V,N} .(view (x, idxs), getindex .(Ref (seeds), 1 : length (idxs)))
45
44
return duals
49
48
function ForwardDiff. extract_gradient! (:: Type{T} , result:: AbstractGPUArray ,
50
49
dual:: Dual ) where {T}
51
50
fn = PartialsFn {T} (dual)
52
- idxs = collect (Iterators . take (ForwardDiff. structural_eachindex (result), npartials (dual)))
51
+ idxs = collect (_take (ForwardDiff. structural_eachindex (result), npartials (dual)))
53
52
result[idxs] .= fn .(1 : length (idxs))
54
53
return result
55
54
end
@@ -59,10 +58,7 @@ function ForwardDiff.extract_gradient_chunk!(::Type{T}, result::AbstractGPUArray
59
58
fn = PartialsFn {T} (dual)
60
59
offset = index - 1
61
60
idxs = collect (
62
- Iterators. take (
63
- Iterators. drop (ForwardDiff. structural_eachindex (result), offset),
64
- chunksize,
65
- )
61
+ _take (Iterators. drop (ForwardDiff. structural_eachindex (result), offset), chunksize)
66
62
)
67
63
result[idxs] .= fn .(1 : length (idxs))
68
64
return result
0 commit comments