|
1 | 1 | module ForwardDiffGPUArraysCoreExt
|
2 | 2 |
|
3 | 3 | using GPUArraysCore: AbstractGPUArray
|
4 |
| -using ForwardDiff: ForwardDiff, Dual, Partials |
| 4 | +using ForwardDiff: ForwardDiff, Dual, Partials, npartials, partials |
5 | 5 |
|
# Seed every structurally indexed entry of `duals` with a dual number built from
# the entry of `x` at the same index and the single partials object `seed`
# (zero partials when `seed` is omitted). Returns `duals`.
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x,
                           seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
    # Materialise the index iterator so it can be used for array-style indexing.
    positions = collect(ForwardDiff.structural_eachindex(duals, x))
    primal_values = x[positions]
    # `Ref(seed)` makes the broadcast treat `seed` as a scalar.
    duals[positions] .= Dual{T,V,N}.(primal_values, Ref(seed))
    return duals
end
|
12 | 12 |
|
# Seed the first `N` structurally indexed entries of `duals`: the i-th selected
# entry becomes `Dual{T,V,N}(x[idx], seeds[i])`. Returns `duals`.
#
# At most `N` indices are taken from the index iterator, so an input exposing
# fewer than `N` structural indices seeds only what exists instead of raising a
# `BoundsError` while slicing the collected index vector.
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x,
                           seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
    idxs = collect(Iterators.take(ForwardDiff.structural_eachindex(duals, x), N))
    # `Ref(seeds)` keeps the tuple whole; each broadcast element picks its own seed,
    # so the seed count always matches the number of indices actually taken.
    duals[idxs] .= Dual{T,V,N}.(x[idxs], getindex.(Ref(seeds), eachindex(idxs)))
    return duals
end
|
19 | 19 |
|
# Seed the structurally indexed entries of `duals` from position `index` onward
# (the first `index - 1` indices are skipped) with dual numbers built from the
# matching entries of `x` and the single partials object `seed` (zero partials
# when `seed` is omitted). Returns `duals`.
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index,
                           seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
    all_positions = ForwardDiff.structural_eachindex(duals, x)
    # Drop the leading `index - 1` positions, then materialise for indexing.
    positions = collect(Iterators.drop(all_positions, index - 1))
    primal_values = x[positions]
    # `Ref(seed)` makes the broadcast treat `seed` as a scalar.
    duals[positions] .= Dual{T,V,N}.(primal_values, Ref(seed))
    return duals
end
|
27 | 27 |
|
# Seed one chunk of `duals`: starting at structural position `index`, the i-th
# of up to `chunksize` entries becomes `Dual{T,V,N}(x[idx], seeds[i])`.
# Returns `duals`.
#
# Only the indices belonging to the chunk are collected (drop, then take). If
# fewer than `chunksize` indices remain, the remaining ones are seeded instead of
# raising a `BoundsError` while slicing the collected index vector.
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index,
                           seeds::NTuple{N,Partials{N,V}}, chunksize = N) where {T,V,N}
    offset = index - 1
    idxs = collect(
        Iterators.take(
            Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset),
            chunksize,
        ),
    )
    # `Ref(seeds)` keeps the tuple whole; each broadcast element picks its own seed.
    # This avoids slicing the tuple with a runtime-sized range.
    duals[idxs] .= Dual{T,V,N}.(x[idxs], getindex.(Ref(seeds), eachindex(idxs)))
    return duals
end
|
35 | 37 |
|
| 38 | +# gradient |
# Write the partials of `dual` into `result`: the i-th structurally indexed entry
# of `result` receives `partials(T, dual, i)`. Returns `result`.
#
# At most `npartials(dual)` indices are taken and they are collected into a
# vector, matching how the other methods in this file index their targets; a
# `result` whose index count differs from `npartials(dual)` therefore no longer
# triggers a broadcast dimension mismatch.
function ForwardDiff.extract_gradient!(::Type{T}, result::AbstractGPUArray,
                                       dual::Dual) where {T}
    # this closure is needed for gpu compilation
    partial_fn(dual, i) = partials(T, dual, i)

    idxs = collect(
        Iterators.take(ForwardDiff.structural_eachindex(result), npartials(dual))
    )
    # `Ref(dual)` makes the broadcast treat `dual` as a scalar.
    result[idxs] .= partial_fn.(Ref(dual), eachindex(idxs))
    return result
end
| 48 | + |
# Write one chunk of partials into `result`: starting at structural position
# `index`, the i-th of up to `chunksize` entries receives `partials(T, dual, i)`.
# Returns `result`.
#
# Only the indices belonging to the chunk are collected (drop, then take). If
# fewer than `chunksize` indices remain, the remaining ones are filled instead of
# raising a `BoundsError` while slicing the collected index vector.
function ForwardDiff.extract_gradient_chunk!(::Type{T}, result::AbstractGPUArray, dual,
                                             index, chunksize) where {T}
    # this closure is needed for gpu compilation
    partial_fn(dual, i) = partials(T, dual, i)

    offset = index - 1
    idxs = collect(
        Iterators.take(
            Iterators.drop(ForwardDiff.structural_eachindex(result), offset),
            chunksize,
        ),
    )
    # `Ref(dual)` makes the broadcast treat `dual` as a scalar.
    result[idxs] .= partial_fn.(Ref(dual), eachindex(idxs))
    return result
end
| 61 | + |
36 | 62 | end
|
0 commit comments