Skip to content

Commit 406a43d

Browse files
committed
fix: use a struct instead of closure
1 parent ce743da commit 406a43d

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

ext/ForwardDiffGPUArraysCoreExt.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,13 @@ module ForwardDiffGPUArraysCoreExt
33
using GPUArraysCore: AbstractGPUArray
44
using ForwardDiff: ForwardDiff, Dual, Partials, npartials, partials
55

6+
struct PartialsFn{T,D<:Dual}
7+
dual::D
8+
end
9+
PartialsFn{T}(dual::Dual) where {T} = PartialsFn{T,typeof(dual)}(dual)
10+
11+
(f::PartialsFn{T})(i) where {T} = partials(T, f.dual, i)
12+
613
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x,
714
seed::Partials{N,V}) where {T,V,N}
815
idxs = collect(ForwardDiff.structural_eachindex(duals, x))
@@ -41,27 +48,23 @@ end
4148
# gradient
4249
function ForwardDiff.extract_gradient!(::Type{T}, result::AbstractGPUArray,
4350
dual::Dual) where {T}
44-
# this closure is needed for gpu compilation
45-
partial_fn(dual, i) = partials(T, dual, i)
46-
51+
fn = PartialsFn{T}(dual)
4752
idxs = collect(Iterators.take(ForwardDiff.structural_eachindex(result), npartials(dual)))
48-
result[idxs] .= partial_fn.(Ref(dual), 1:length(idxs))
53+
result[idxs] .= fn.(1:length(idxs))
4954
return result
5055
end
5156

5257
function ForwardDiff.extract_gradient_chunk!(::Type{T}, result::AbstractGPUArray, dual,
5358
index, chunksize) where {T}
54-
# this closure is needed for gpu compilation
55-
partial_fn(dual, i) = partials(T, dual, i)
56-
59+
fn = PartialsFn{T}(dual)
5760
offset = index - 1
5861
idxs = collect(
5962
Iterators.take(
6063
Iterators.drop(ForwardDiff.structural_eachindex(result), offset),
6164
chunksize,
6265
)
6366
)
64-
result[idxs] .= partial_fn.(Ref(dual), 1:length(idxs))
67+
result[idxs] .= fn.(1:length(idxs))
6568
return result
6669
end
6770

0 commit comments

Comments
 (0)