Skip to content

Commit ac86cac

Browse files
committed
Fix gradient and Jacobian for functions with Dual output (#770)
* Fix gradient and Jacobian for functions with `Dual` output * Bump version from 1.2.0 to 1.2.1
1 parent 4447f47 commit ac86cac

File tree

7 files changed

+95
-7
lines changed

7 files changed

+95
-7
lines changed

ext/ForwardDiffStaticArraysExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ ForwardDiff._lyap_div!!(A::StaticArrays.MMatrix, λ::AbstractVector) = ForwardDi
4747
result = Expr(:tuple, [:(partials(T, y, $i)) for i in 1:length(x)]...)
4848
return quote
4949
$(Expr(:meta, :inline))
50-
V = StaticArrays.similar_type(S, valtype($y))
50+
V = StaticArrays.similar_type(S, valtype(T, $y))
5151
return V($result)
5252
end
5353
end
@@ -76,13 +76,13 @@ end
7676
result = Expr(:tuple, [:(partials(T, ydual[$i], $j)) for i in 1:M, j in 1:N]...)
7777
return quote
7878
$(Expr(:meta, :inline))
79-
V = StaticArrays.similar_type(S, valtype(eltype($ydual)), Size($M, $N))
79+
V = StaticArrays.similar_type(S, valtype(T, eltype($ydual)), Size($M, $N))
8080
return V($result)
8181
end
8282
end
8383

8484
function extract_jacobian(::Type{T}, ydual::AbstractArray, x::StaticArray) where T
85-
result = similar(ydual, valtype(eltype(ydual)), length(ydual), length(x))
85+
result = similar(ydual, valtype(T, eltype(ydual)), length(ydual), length(x))
8686
return extract_jacobian!(T, result, ydual, length(x))
8787
end
8888

src/dual.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,17 @@ end
128128
@inline valtype(::Dual{T,V,N}) where {T,V,N} = V
129129
@inline valtype(::Type{Dual{T,V,N}}) where {T,V,N} = V
130130

131+
@inline valtype(::Type{T}, ::V) where {T,V} = valtype(T, V)
132+
@inline valtype(::Type, ::Type{V}) where {V} = V
133+
@inline valtype(::Type{T}, ::Type{Dual{T,V,N}}) where {T,V,N} = V
134+
@inline function valtype(::Type{T}, ::Type{Dual{S,V,N}}) where {T,S,V,N}
135+
if S T
136+
Dual{S,V,N}
137+
else
138+
throw(DualMismatchError(T,S))
139+
end
140+
end
141+
131142
@inline tagtype(::V) where {V} = Nothing
132143
@inline tagtype(::Type{V}) where {V} = Nothing
133144
@inline tagtype(::Dual{T,V,N}) where {T,V,N} = T

src/gradient.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ const GRAD_ERROR = DimensionMismatch("gradient(f, x) expects that f(x) is a real
9090
function vector_mode_gradient(f::F, x, cfg::GradientConfig{T}) where {T, F}
9191
ydual = vector_mode_dual_eval!(f, cfg, x)
9292
ydual isa Real || throw(GRAD_ERROR)
93-
result = similar(x, valtype(ydual))
93+
result = similar(x, valtype(T, ydual))
9494
return extract_gradient!(T, result, ydual)
9595
end
9696

@@ -149,7 +149,7 @@ function chunk_mode_gradient_expr(result_definition::Expr)
149149
end
150150

151151
@eval function chunk_mode_gradient(f::F, x, cfg::GradientConfig{T,V,N}) where {F,T,V,N}
152-
$(chunk_mode_gradient_expr(:(result = similar(x, valtype(ydual)))))
152+
$(chunk_mode_gradient_expr(:(result = similar(x, valtype(T, ydual)))))
153153
end
154154

155155
@eval function chunk_mode_gradient!(result, f::F, x, cfg::GradientConfig{T,V,N}) where {F,T,V,N}

src/jacobian.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ function vector_mode_jacobian(f::F, x, cfg::JacobianConfig{T}) where {F,T}
128128
N = chunksize(cfg)
129129
ydual = vector_mode_dual_eval!(f, cfg, x)
130130
ydual isa AbstractArray || throw(JACOBIAN_ERROR)
131-
result = similar(ydual, valtype(eltype(ydual)), length(ydual), N)
131+
result = similar(ydual, valtype(T, eltype(ydual)), length(ydual), N)
132132
extract_jacobian!(T, result, ydual, N)
133133
extract_value!(T, result, ydual)
134134
return result
@@ -217,7 +217,7 @@ end
217217
seed!(xdual, x)
218218
end,
219219
:(ydual = f(xdual)),
220-
:(result = similar(ydual, valtype(eltype(ydual)), length(ydual), xlen)),
220+
:(result = similar(ydual, valtype(T, eltype(ydual)), length(ydual), xlen)),
221221
:()))
222222
end
223223

test/DualTest.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,21 @@ ForwardDiff.:≺(::Type{OuterTestTag}, ::Type{TestTag}) = false
9494
@test ForwardDiff.valtype(NESTED_FDNUM) == Dual{TestTag,V,M}
9595
@test ForwardDiff.valtype(typeof(NESTED_FDNUM)) == Dual{TestTag,V,M}
9696

97+
@test ForwardDiff.valtype(TestTag, FDNUM) == V
98+
@test ForwardDiff.valtype(TestTag, typeof(FDNUM)) == V
99+
@test ForwardDiff.valtype(TestTag, NESTED_FDNUM) == Dual{TestTag,V,M}
100+
@test ForwardDiff.valtype(TestTag, typeof(NESTED_FDNUM)) == Dual{TestTag,V,M}
101+
102+
@test ForwardDiff.valtype(OuterTestTag, FDNUM) == Dual{TestTag,V,N}
103+
@test ForwardDiff.valtype(OuterTestTag, typeof(FDNUM)) == Dual{TestTag,V,N}
104+
@test ForwardDiff.valtype(OuterTestTag, NESTED_FDNUM) == Dual{TestTag,Dual{TestTag,V,M},N}
105+
@test ForwardDiff.valtype(OuterTestTag, typeof(NESTED_FDNUM)) == Dual{TestTag,Dual{TestTag,V,M},N}
106+
107+
@test_throws ForwardDiff.DualMismatchError(TestTag, OuterTestTag) ForwardDiff.valtype(TestTag, Dual{OuterTestTag}(PRIMAL, PARTIALS))
108+
@test_throws ForwardDiff.DualMismatchError(TestTag, OuterTestTag) ForwardDiff.valtype(TestTag, typeof(Dual{OuterTestTag}(PRIMAL, PARTIALS)))
109+
@test_throws ForwardDiff.DualMismatchError(TestTag, OuterTestTag) ForwardDiff.valtype(TestTag, Dual{OuterTestTag}(Dual{TestTag}(PRIMAL, M_PARTIALS), NESTED_PARTIALS))
110+
@test_throws ForwardDiff.DualMismatchError(TestTag, OuterTestTag) ForwardDiff.valtype(TestTag, typeof(Dual{OuterTestTag}(Dual{TestTag}(PRIMAL, M_PARTIALS), NESTED_PARTIALS)))
111+
97112
#####################
98113
# Generic Functions #
99114
#####################

test/GradientTest.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ using DiffTests
1111

1212
include(joinpath(dirname(@__FILE__), "utils.jl"))
1313

14+
struct TestTag end
15+
struct OuterTestTag end
16+
ForwardDiff.:(::Type{TestTag}, ::Type{OuterTestTag}) = true
17+
ForwardDiff.:(::Type{OuterTestTag}, ::Type{<:Tag}) = true
18+
1419
##################
1520
# hardcoded test #
1621
##################
@@ -179,4 +184,30 @@ end
179184
@test_throws DomainError ForwardDiff.gradient(x -> x[1]^x[2], [-1.0, 0.5])
180185
end
181186

187+
# issue #769
188+
@testset "functions with `Dual` output" begin
189+
x = [Dual{OuterTestTag}(Dual{TestTag}(1.3, 2.1), Dual{TestTag}(0.3, -2.4))]
190+
f(x) = sum(ForwardDiff.value, x)
191+
der = ForwardDiff.derivative(ForwardDiff.value, only(x))
192+
193+
# Vector mode
194+
grad = ForwardDiff.gradient(f, x)
195+
@test grad isa Vector{typeof(der)}
196+
@test grad == [der]
197+
grad = ForwardDiff.gradient(f, SVector{1}(x))
198+
@test grad isa SVector{1,typeof(der)}
199+
@test grad == SVector{1}(der)
200+
201+
# Chunk mode
202+
y = repeat(x, 3)
203+
cfg = ForwardDiff.GradientConfig(f, y, ForwardDiff.Chunk{2}())
204+
grad = ForwardDiff.gradient(f, y, cfg)
205+
@test grad isa Vector{typeof(der)}
206+
@test grad == [der, der, der]
207+
cfg = ForwardDiff.GradientConfig(f, SVector{3}(y), ForwardDiff.Chunk{2}())
208+
grad = ForwardDiff.gradient(f, SVector{3}(y), cfg)
209+
@test grad isa SVector{3,typeof(der)}
210+
@test grad == SVector{3}(der, der, der)
211+
end
212+
182213
end # module

test/JacobianTest.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ using LinearAlgebra
1111

1212
include(joinpath(dirname(@__FILE__), "utils.jl"))
1313

14+
struct TestTag end
15+
struct OuterTestTag end
16+
ForwardDiff.:(::Type{TestTag}, ::Type{OuterTestTag}) = true
17+
ForwardDiff.:(::Type{OuterTestTag}, ::Type{<:Tag}) = true
18+
1419
##################
1520
# hardcoded test #
1621
##################
@@ -255,4 +260,30 @@ end
255260
@inferred ForwardDiff.jacobian(g!, [1.0], [0.0])
256261
end
257262

263+
# issue #769
264+
@testset "functions with `Dual` output" begin
265+
x = [Dual{OuterTestTag}(Dual{TestTag}(1.3, 2.1), Dual{TestTag}(0.3, -2.4))]
266+
f(x) = map(ForwardDiff.value, x)
267+
der = ForwardDiff.derivative(ForwardDiff.value, only(x))
268+
269+
# Vector mode
270+
jac = ForwardDiff.jacobian(f, x)
271+
@test jac isa Matrix{typeof(der)}
272+
@test jac == [der;;]
273+
jac = ForwardDiff.jacobian(f, SVector{1}(x))
274+
@test jac isa SMatrix{1,1,typeof(der)}
275+
@test jac == SMatrix{1,1}(der)
276+
277+
# Chunk mode
278+
y = repeat(x, 3)
279+
cfg = ForwardDiff.JacobianConfig(f, y, ForwardDiff.Chunk{2}())
280+
jac = ForwardDiff.jacobian(f, y, cfg)
281+
@test jac isa Matrix{typeof(der)}
282+
@test jac == Diagonal([der, der, der])
283+
cfg = ForwardDiff.JacobianConfig(f, SVector{3}(y), ForwardDiff.Chunk{2}())
284+
jac = ForwardDiff.jacobian(f, SVector{3}(y), cfg)
285+
@test jac isa SMatrix{3,3,typeof(der)}
286+
@test jac == Diagonal([der, der, der])
287+
end
288+
258289
end # module

0 commit comments

Comments
 (0)