Skip to content

Commit c722819

Browse files
committed
fix: revert _take
1 parent 7a8828a commit c722819

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

ext/ForwardDiffGPUArraysCoreExt.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ PartialsFn{T}(dual::Dual) where {T} = PartialsFn{T,typeof(dual)}(dual)
1010

1111
(f::PartialsFn{T})(i) where {T} = partials(T, f.dual, i)
1212

13-
_take(itr, N::Integer) = Iterators.take(itr, min(length(itr), N))
14-
1513
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x,
1614
seed::Partials{N,V}) where {T,V,N}
1715
idxs = collect(ForwardDiff.structural_eachindex(duals, x))
@@ -21,7 +19,7 @@ end
2119

2220
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x,
2321
seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
24-
idxs = collect(_take(ForwardDiff.structural_eachindex(duals, x), N))
22+
idxs = collect(Iterators.take(ForwardDiff.structural_eachindex(duals, x), N))
2523
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs)))
2624
return duals
2725
end
@@ -38,7 +36,7 @@ function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index,
3836
seeds::NTuple{N,Partials{N,V}}, chunksize) where {T,V,N}
3937
offset = index - 1
4038
idxs = collect(
41-
_take(Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset), chunksize)
39+
Iterators.take(Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset), chunksize)
4240
)
4341
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs)))
4442
return duals
@@ -48,7 +46,7 @@ end
4846
function ForwardDiff.extract_gradient!(::Type{T}, result::AbstractGPUArray,
4947
dual::Dual) where {T}
5048
fn = PartialsFn{T}(dual)
51-
idxs = collect(_take(ForwardDiff.structural_eachindex(result), npartials(dual)))
49+
idxs = collect(Iterators.take(ForwardDiff.structural_eachindex(result), npartials(dual)))
5250
result[idxs] .= fn.(1:length(idxs))
5351
return result
5452
end
@@ -58,7 +56,7 @@ function ForwardDiff.extract_gradient_chunk!(::Type{T}, result::AbstractGPUArray
5856
fn = PartialsFn{T}(dual)
5957
offset = index - 1
6058
idxs = collect(
61-
_take(Iterators.drop(ForwardDiff.structural_eachindex(result), offset), chunksize)
59+
Iterators.take(Iterators.drop(ForwardDiff.structural_eachindex(result), offset), chunksize)
6260
)
6361
result[idxs] .= fn.(1:length(idxs))
6462
return result

0 commit comments

Comments
 (0)