@@ -10,8 +10,6 @@ PartialsFn{T}(dual::Dual) where {T} = PartialsFn{T,typeof(dual)}(dual)
10
10
11
11
(f:: PartialsFn{T} )(i) where {T} = partials (T, f. dual, i)
12
12
13
- _take (itr, N:: Integer ) = Iterators. take (itr, min (length (itr), N))
14
-
15
13
function ForwardDiff. seed! (duals:: AbstractGPUArray{Dual{T,V,N}} , x,
16
14
seed:: Partials{N,V} ) where {T,V,N}
17
15
idxs = collect (ForwardDiff. structural_eachindex (duals, x))
21
19
22
20
function ForwardDiff. seed! (duals:: AbstractGPUArray{Dual{T,V,N}} , x,
23
21
seeds:: NTuple{N,Partials{N,V}} ) where {T,V,N}
24
- idxs = collect (_take (ForwardDiff. structural_eachindex (duals, x), N))
22
+ idxs = collect (Iterators . take (ForwardDiff. structural_eachindex (duals, x), N))
25
23
duals[idxs] .= Dual {T,V,N} .(view (x, idxs), getindex .(Ref (seeds), 1 : length (idxs)))
26
24
return duals
27
25
end
@@ -38,7 +36,7 @@ function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index,
38
36
seeds:: NTuple{N,Partials{N,V}} , chunksize) where {T,V,N}
39
37
offset = index - 1
40
38
idxs = collect (
41
- _take (Iterators. drop (ForwardDiff. structural_eachindex (duals, x), offset), chunksize)
39
+ Iterators . take (Iterators. drop (ForwardDiff. structural_eachindex (duals, x), offset), chunksize)
42
40
)
43
41
duals[idxs] .= Dual {T,V,N} .(view (x, idxs), getindex .(Ref (seeds), 1 : length (idxs)))
44
42
return duals
48
46
function ForwardDiff. extract_gradient! (:: Type{T} , result:: AbstractGPUArray ,
49
47
dual:: Dual ) where {T}
50
48
fn = PartialsFn {T} (dual)
51
- idxs = collect (_take (ForwardDiff. structural_eachindex (result), npartials (dual)))
49
+ idxs = collect (Iterators . take (ForwardDiff. structural_eachindex (result), npartials (dual)))
52
50
result[idxs] .= fn .(1 : length (idxs))
53
51
return result
54
52
end
@@ -58,7 +56,7 @@ function ForwardDiff.extract_gradient_chunk!(::Type{T}, result::AbstractGPUArray
58
56
fn = PartialsFn {T} (dual)
59
57
offset = index - 1
60
58
idxs = collect (
61
- _take (Iterators. drop (ForwardDiff. structural_eachindex (result), offset), chunksize)
59
+ Iterators . take (Iterators. drop (ForwardDiff. structural_eachindex (result), offset), chunksize)
62
60
)
63
61
result[idxs] .= fn .(1 : length (idxs))
64
62
return result
0 commit comments