diff --git a/Project.toml b/Project.toml index b53c7c1c..5772a980 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ForwardDiff" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "1.2.2" +version = "1.3.0" [deps] CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950" diff --git a/src/dual.jl b/src/dual.jl index c070644f..f7aeb1d7 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -735,6 +735,57 @@ end return (Dual{T}(sd, cd * π * partials(d)), Dual{T}(cd, -sd * π * partials(d))) end +# LinearAlgebra.givensAlgorithm # +#-------------------------------# + +# This definition ensures that we match `LinearAlgebra.givensAlgorithm` +# for non-dual numbers (i.e., `ForwardDiff.Dual` with zero partials) +# `LinearAlgebra.givensAlgorithm` is derived from LAPACK's dlartg +# which is [documented](https://netlib.org/lapack/explore-html/da/dd3/group__lartg_ga86f8f877eaea0386cdc2c3c175d9ea88.html) to return +# three values c, s, u for two arguments x and y with +# u = sgn(x) sqrt(x^2 + y^2) +# c = x/u +# s = y/u +# The function is discontinuous in u at x=0 +@define_binary_dual_op( + LinearAlgebra.givensAlgorithm, + begin + vx, vy = value(x), value(y) + c, s, u = LinearAlgebra.givensAlgorithm(vx, vy) + ∂c∂x = s^2 / u + ∂c∂y = ∂s∂x = -(c * s / u) + ∂s∂y = c^2 / u + ∂x = partials(x) + ∂y = partials(y) + ∂c = _mul_partials(∂x, ∂y, ∂c∂x, ∂c∂y) + ∂s = _mul_partials(∂x, ∂y, ∂s∂x, ∂s∂y) + ∂u = _mul_partials(∂x, ∂y, c, s) + return Dual{Txy}(c, ∂c), Dual{Txy}(s, ∂s), Dual{Txy}(u, ∂u) + end, + begin + vx = value(x) + c, s, u = LinearAlgebra.givensAlgorithm(vx, y) + ∂c∂x = s^2 / u + ∂s∂x = -(c * s / u) + ∂x = partials(x) + ∂c = ∂c∂x * ∂x + ∂s = ∂s∂x * ∂x + ∂u = c * ∂x + return Dual{Tx}(c, ∂c), Dual{Tx}(s, ∂s), Dual{Tx}(u, ∂u) + end, + begin + vy = value(y) + c, s, u = LinearAlgebra.givensAlgorithm(x, vy) + ∂c∂y = -(c * s / u) + ∂s∂y = c^2 / u + ∂y = partials(y) + ∂c = ∂c∂y * ∂y + ∂s = ∂s∂y * ∂y + ∂u = s * ∂y + return Dual{Ty}(c, ∂c), Dual{Ty}(s, ∂s), Dual{Ty}(u, ∂u) + end, +) + # Symmetric eigvals # #-------------------# diff --git a/test/DerivativeTest.jl b/test/DerivativeTest.jl index ab5a3631..d66a7cc7 100644 --- a/test/DerivativeTest.jl +++ b/test/DerivativeTest.jl @@ -1,6 +1,7 @@ module DerivativeTest import Calculus +import LinearAlgebra import NaNMath using Test @@ -122,4 +123,14 @@ end end end +@testset "Givens rotations: Derivatives" begin + # Test different branches in `LinearAlgebra.givensAlgorithm` + for f in [randexp(), -randexp()], g in [0.0, f / 2, 2f, -f / 2, -2f], i in 1:3 + @test ForwardDiff.derivative(x -> LinearAlgebra.givensAlgorithm(x, g)[i], f) ≈ + Calculus.derivative(x -> LinearAlgebra.givensAlgorithm(x, g)[i], f) + @test ForwardDiff.derivative(x -> LinearAlgebra.givensAlgorithm(f, x)[i], g) ≈ + Calculus.derivative(x -> LinearAlgebra.givensAlgorithm(f, x)[i], g) + end +end + end # module diff --git a/test/DualTest.jl b/test/DualTest.jl index 8ca5d2b1..35d3116b 100644 --- a/test/DualTest.jl +++ b/test/DualTest.jl @@ -10,6 +10,7 @@ using NaNMath, SpecialFunctions, LogExpFunctions using DiffRules import Calculus +import LinearAlgebra struct TestTag end struct OuterTestTag end @@ -685,4 +686,31 @@ end @test ForwardDiff.derivative(x -> sum(1 .+ x .* (0:0.1:1)), 1) == 5.5 end +@testset "Givens rotations: consistency with `LinearAlgebra.givensAlgorithm` for zero partials (no duals)" begin + # Test different branches in `LinearAlgebra.givensAlgorithm` + for f in [randexp(), -randexp()], g in [0.0, f / 2, 2f, -f / 2, -2f] + # Upstream: Result for non-dual numbers + y = LinearAlgebra.givensAlgorithm(f, g) + @test y isa NTuple{3,Float64} + + for n in (1, 2, 5) + zero_tuple = ntuple(Returns(0.0), n) + dual_f = Dual{TestTag}(f, zero_tuple) + dual_g = Dual{TestTag}(g, zero_tuple) + for (_f, _g) in ((dual_f, dual_g), (dual_f, g), (f, dual_g)) + ydual = @inferred(LinearAlgebra.givensAlgorithm(_f, _g)) + @test ydual isa NTuple{3,Dual{TestTag,Float64,n}} + + for (i, yi, yduali) in zip(1:3, y, ydual) + # Primal values must match `LinearAlgebra.givensAlgorithm` with `Float64` inputs + @test ForwardDiff.value(yduali) ≈ yi + + # Partial derivatives must be zero (zero in - zero out) + @test iszero(ForwardDiff.partials(yduali)) + end + end + end + end +end + end # module diff --git a/test/GradientTest.jl b/test/GradientTest.jl index 9008ce8d..bf121239 100644 --- a/test/GradientTest.jl +++ b/test/GradientTest.jl @@ -1,6 +1,7 @@ module GradientTest import Calculus +import LinearAlgebra import NaNMath using Test @@ -330,4 +331,20 @@ end end end +@testset "Givens rotations: Gradients" begin + # Test different branches in `LinearAlgebra.givensAlgorithm` + for f in [randexp(), -randexp()], g in [0.0, f / 2, 2f, -f / 2, -2f], i in 1:3 + # Gradients wrt to a single input argument + dydf = only(ForwardDiff.gradient(x -> LinearAlgebra.givensAlgorithm(only(x), g)[i], [f])) + @test dydf == ForwardDiff.derivative(x -> LinearAlgebra.givensAlgorithm(x, g)[i], f) + dydg = only(ForwardDiff.gradient(x -> LinearAlgebra.givensAlgorithm(f, only(x))[i], [g])) + @test dydg == ForwardDiff.derivative(x -> LinearAlgebra.givensAlgorithm(f, x)[i], g) + + # Gradient with respect to both input arguments + grad = ForwardDiff.gradient(x -> LinearAlgebra.givensAlgorithm(x[1], x[2])[i], [f, g]) + @test grad == [dydf, dydg] + @test grad ≈ Calculus.gradient(x -> LinearAlgebra.givensAlgorithm(x[1], x[2])[i], [f, g]) + end +end + end # module