Skip to content

Commit ce743da

Browse files
avik-pal and devmotion authored
fix: apply suggestions from code review
Co-authored-by: David Widmann <[email protected]>
1 parent da2efb7 commit ce743da

File tree

1 file changed

+21
-15
lines changed

1 file changed

+21
-15
lines changed

ext/ForwardDiffGPUArraysCoreExt.jl

Lines changed: 21 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -4,34 +4,37 @@ using GPUArraysCore: AbstractGPUArray
44
using ForwardDiff: ForwardDiff, Dual, Partials, npartials, partials
55

66
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x,
7-
seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
7+
seed::Partials{N,V}) where {T,V,N}
88
idxs = collect(ForwardDiff.structural_eachindex(duals, x))
9-
duals[idxs] .= Dual{T,V,N}.(x[idxs], Ref(seed))
9+
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), Ref(seed))
1010
return duals
1111
end
1212

1313
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x,
1414
seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
15-
idxs = collect(ForwardDiff.structural_eachindex(duals, x))[1:N]
16-
duals[idxs] .= Dual{T,V,N}.(x[idxs], seeds[1:N])
15+
idxs = collect(Iterators.take(ForwardDiff.structural_eachindex(duals, x), N))
16+
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs)))
1717
return duals
1818
end
1919

2020
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index,
21-
seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
21+
seed::Partials{N,V}) where {T,V,N}
2222
offset = index - 1
2323
idxs = collect(Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset))
24-
duals[idxs] .= Dual{T,V,N}.(x[idxs], Ref(seed))
24+
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), Ref(seed))
2525
return duals
2626
end
2727

2828
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index,
29-
seeds::NTuple{N,Partials{N,V}}, chunksize = N) where {T,V,N}
29+
seeds::NTuple{N,Partials{N,V}}, chunksize) where {T,V,N}
3030
offset = index - 1
3131
idxs = collect(
32-
Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset)
33-
)[1:chunksize]
34-
duals[idxs] .= Dual{T,V,N}.(x[idxs], seeds[1:chunksize])
32+
Iterators.take(
33+
Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset),
34+
chunksize,
35+
),
36+
)
37+
duals[idxs] .= Dual{T,V,N}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs)))
3538
return duals
3639
end
3740

@@ -41,8 +44,8 @@ function ForwardDiff.extract_gradient!(::Type{T}, result::AbstractGPUArray,
4144
# this closure is needed for gpu compilation
4245
partial_fn(dual, i) = partials(T, dual, i)
4346

44-
idxs = ForwardDiff.structural_eachindex(result)
45-
result[idxs] .= partial_fn.(Ref(dual), 1:npartials(dual))
47+
idxs = collect(Iterators.take(ForwardDiff.structural_eachindex(result), npartials(dual)))
48+
result[idxs] .= partial_fn.(Ref(dual), 1:length(idxs))
4649
return result
4750
end
4851

@@ -53,9 +56,12 @@ function ForwardDiff.extract_gradient_chunk!(::Type{T}, result::AbstractGPUArray
5356

5457
offset = index - 1
5558
idxs = collect(
56-
Iterators.drop(ForwardDiff.structural_eachindex(result), offset)
57-
)[1:chunksize]
58-
result[idxs] .= partial_fn.(Ref(dual), 1:chunksize)
59+
Iterators.take(
60+
Iterators.drop(ForwardDiff.structural_eachindex(result), offset),
61+
chunksize,
62+
)
63+
)
64+
result[idxs] .= partial_fn.(Ref(dual), 1:length(idxs))
5965
return result
6066
end
6167

0 commit comments

Comments
 (0)