Skip to content

Commit c4c62a4

Browse files
committed
fix: sizecheck
1 parent 03b232e commit c4c62a4

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

ext/ForwardDiffGPUArraysCoreExt.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ PartialsFn{T}(dual::Dual) where {T} = PartialsFn{T,typeof(dual)}(dual)
1010

1111
(f::PartialsFn{T})(i) where {T} = partials(T, f.dual, i)
1212

13+
_take(itr, N::Integer) = Iterators.take(itr, min(length(itr), N))
14+
1315
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x,
1416
seed::Partials{N,V}) where {T,V,N}
1517
idxs = collect(ForwardDiff.structural_eachindex(duals, x))
@@ -19,7 +21,7 @@ end
1921

2022
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x,
2123
seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
22-
idxs = collect(Iterators.take(ForwardDiff.structural_eachindex(duals, x), N))
24+
idxs = collect(_take(ForwardDiff.structural_eachindex(duals, x), N))
2325
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs)))
2426
return duals
2527
end
@@ -36,10 +38,7 @@ function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index,
3638
seeds::NTuple{N,Partials{N,V}}, chunksize) where {T,V,N}
3739
offset = index - 1
3840
idxs = collect(
39-
Iterators.take(
40-
Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset),
41-
chunksize,
42-
),
41+
_take(Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset), chunksize)
4342
)
4443
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs)))
4544
return duals
@@ -49,7 +48,7 @@ end
4948
function ForwardDiff.extract_gradient!(::Type{T}, result::AbstractGPUArray,
5049
dual::Dual) where {T}
5150
fn = PartialsFn{T}(dual)
52-
idxs = collect(Iterators.take(ForwardDiff.structural_eachindex(result), npartials(dual)))
51+
idxs = collect(_take(ForwardDiff.structural_eachindex(result), npartials(dual)))
5352
result[idxs] .= fn.(1:length(idxs))
5453
return result
5554
end
@@ -59,10 +58,7 @@ function ForwardDiff.extract_gradient_chunk!(::Type{T}, result::AbstractGPUArray
5958
fn = PartialsFn{T}(dual)
6059
offset = index - 1
6160
idxs = collect(
62-
Iterators.take(
63-
Iterators.drop(ForwardDiff.structural_eachindex(result), offset),
64-
chunksize,
65-
)
61+
_take(Iterators.drop(ForwardDiff.structural_eachindex(result), offset), chunksize)
6662
)
6763
result[idxs] .= fn.(1:length(idxs))
6864
return result

0 commit comments

Comments
 (0)