Skip to content

Commit f230880

Browse files
committed
fix: regression in non-fast scalar indexing support
1 parent d90a505 commit f230880

File tree

3 files changed

+56
-2
lines changed

3 files changed

+56
-2
lines changed

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ForwardDiff"
22
uuid = "f6369f11-7733-5829-9624-2563aa707210"
3-
version = "1.0.1"
3+
version = "1.0.2"
44

55
[deps]
66
CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950"
@@ -16,9 +16,11 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1616
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1717

1818
[weakdeps]
19+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1920
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2021

2122
[extensions]
23+
ForwardDiffGPUArraysCoreExt = "GPUArraysCore"
2224
ForwardDiffStaticArraysExt = "StaticArrays"
2325

2426
[compat]
@@ -27,6 +29,7 @@ CommonSubexpressions = "0.3"
2729
DiffResults = "1.1"
2830
DiffRules = "1.4"
2931
DiffTests = "0.1"
32+
GPUArraysCore = "0.2"
3033
IrrationalConstants = "0.1, 0.2"
3134
LogExpFunctions = "0.3"
3235
NaNMath = "1"
@@ -40,9 +43,10 @@ Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
4043
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
4144
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
4245
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
46+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
4347
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
4448
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
4549
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4650

4751
[targets]
48-
test = ["Calculus", "DiffTests", "IrrationalConstants", "SparseArrays", "StaticArrays", "Test", "InteractiveUtils"]
52+
test = ["Calculus", "DiffTests", "IrrationalConstants", "SparseArrays", "StaticArrays", "Test", "InteractiveUtils", "JLArrays"]

ext/ForwardDiffGPUArraysCoreExt.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
module ForwardDiffGPUArraysCoreExt
2+
3+
using GPUArraysCore: AbstractGPUArray
4+
using ForwardDiff: ForwardDiff, Dual, Partials
5+
6+
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x,
7+
seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
8+
idxs = ForwardDiff.structural_eachindex(duals, x)
9+
duals[idxs] .= Dual{T,V,N}.(x[idxs], seed)
10+
return duals
11+
end
12+
13+
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x,
14+
seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
15+
idxs = ForwardDiff.structural_eachindex(duals, x)
16+
duals[idxs] .= Dual{T,V,N}.(x[idxs], seeds[1:N])
17+
return duals
18+
end
19+
20+
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index,
21+
seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
22+
offset = index - 1
23+
idxs = Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset)
24+
duals[idxs] .= Dual{T,V,N}.(x[idxs], seed)
25+
return duals
26+
end
27+
28+
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index,
29+
seeds::NTuple{N,Partials{N,V}}, chunksize = N) where {T,V,N}
30+
offset = index - 1
31+
idxs = Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset)
32+
duals[idxs] .= Dual{T,V,N}.(x[idxs], seeds[1:chunksize])
33+
return duals
34+
end
35+
36+
end

test/JacobianTest.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using ForwardDiff: Dual, Tag, JacobianConfig
88
using StaticArrays
99
using DiffTests
1010
using LinearAlgebra
11+
using JLArrays
1112

1213
include(joinpath(dirname(@__FILE__), "utils.jl"))
1314

@@ -279,4 +280,17 @@ end
279280
end
280281
end
281282

283+
@testset "GPUArraysCore" begin
284+
f(x) = x .^ 2 ./ 2
285+
286+
x = [1.0, 2.0, 3.0]
287+
x_jl = JLArray(x)
288+
289+
jac = ForwardDiff.jacobian(f, x)
290+
jac_jl = ForwardDiff.jacobian(f, x_jl)
291+
292+
@test jac_jl isa JLArray
293+
@test Array(jac_jl) jac
294+
end
295+
282296
end # module

0 commit comments

Comments
 (0)