|
@testitem "ForwardDiff.jl Integration NonlinearLeastSquaresProblem" tags=[:core] begin
    using ForwardDiff, FiniteDiff, SimpleNonlinearSolve, StaticArrays, LinearAlgebra,
        Zygote, ReverseDiff
    using DifferentiationInterface

    const DI = DifferentiationInterface

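    # Data-generating model for the least-squares fit. The sample points `x` are
    # passed to the problem as the parameter `p`, so the gradients computed below
    # are taken with respect to them.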
    true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])

    θ_true = [1.0, 0.1, 2.0, 0.5]
    x = [-1.0, -0.5, 0.0, 0.5, 1.0]
    y_target = true_function(x, θ_true)

    loss_function(θ, p) = true_function(p, θ) .- y_target

    loss_function_jac(θ, p) = ForwardDiff.jacobian(Base.Fix2(loss_function, p), θ)

    loss_function_vjp(v, θ, p) = reshape(vec(v)' * loss_function_jac(θ, p), size(θ))

    function loss_function!(resid, θ, p)
        ŷ = true_function(p, θ)
        @. resid = ŷ - y_target
        return
    end

    function loss_function_jac!(J, θ, p)
        J .= ForwardDiff.jacobian(θ -> loss_function(θ, p), θ)
        return
    end

    function loss_function_vjp!(vJ, v, θ, p)
        vec(vJ) .= reshape(vec(v)' * loss_function_jac(θ, p), size(θ))
        return
    end

    θ_init = θ_true .+ 0.1

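    # For each autodiff backend used inside SimpleGaussNewton, differentiate the
    # scalar objective `sum(abs2, sol.u)` with respect to the sample points and
    # check that ForwardDiff through the solve agrees with FiniteDiff.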
    @testset for alg in (
        SimpleGaussNewton(),
        SimpleGaussNewton(; autodiff = AutoForwardDiff()),
        SimpleGaussNewton(; autodiff = AutoFiniteDiff()),
        SimpleGaussNewton(; autodiff = AutoReverseDiff())
    )
        function obj_1(p)
            prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, p)
            sol = solve(prob_oop, alg)
            return sum(abs2, sol.u)
        end

        function obj_2(p)
            ff = NonlinearFunction{false}(
                loss_function; resid_prototype = zeros(length(y_target)))
            prob_oop = NonlinearLeastSquaresProblem{false}(ff, θ_init, p)
            sol = solve(prob_oop, alg)
            return sum(abs2, sol.u)
        end

        function obj_3(p)
            ff = NonlinearFunction{false}(loss_function; vjp = loss_function_vjp)
            prob_oop = NonlinearLeastSquaresProblem{false}(ff, θ_init, p)
            sol = solve(prob_oop, alg)
            return sum(abs2, sol.u)
        end

        finitediff = DI.gradient(obj_1, AutoFiniteDiff(), x)

        fdiff1 = DI.gradient(obj_1, AutoForwardDiff(), x)
        fdiff2 = DI.gradient(obj_2, AutoForwardDiff(), x)
        fdiff3 = DI.gradient(obj_3, AutoForwardDiff(), x)

        @test finitediff≈fdiff1 atol=1e-5
        @test finitediff≈fdiff2 atol=1e-5
        @test finitediff≈fdiff3 atol=1e-5
        @test fdiff1 ≈ fdiff2 ≈ fdiff3

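        # Repeat the same gradient checks with in-place (mutating) residual,
        # Jacobian, and VJP definitions.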
        function obj_4(p)
            prob_iip = NonlinearLeastSquaresProblem(
                NonlinearFunction{true}(
                    loss_function!; resid_prototype = zeros(length(y_target))),
                θ_init,
                p)
            sol = solve(prob_iip, alg)
            return sum(abs2, sol.u)
        end

        function obj_5(p)
            ff = NonlinearFunction{true}(
                loss_function!; resid_prototype = zeros(length(y_target)),
                jac = loss_function_jac!)
            prob_iip = NonlinearLeastSquaresProblem(ff, θ_init, p)
            sol = solve(prob_iip, alg)
            return sum(abs2, sol.u)
        end

        function obj_6(p)
            ff = NonlinearFunction{true}(
                loss_function!; resid_prototype = zeros(length(y_target)),
                vjp = loss_function_vjp!)
            prob_iip = NonlinearLeastSquaresProblem(ff, θ_init, p)
            sol = solve(prob_iip, alg)
            return sum(abs2, sol.u)
        end

        finitediff = DI.gradient(obj_4, AutoFiniteDiff(), x)

        fdiff4 = DI.gradient(obj_4, AutoForwardDiff(), x)
        fdiff5 = DI.gradient(obj_5, AutoForwardDiff(), x)
        fdiff6 = DI.gradient(obj_6, AutoForwardDiff(), x)

        @test finitediff≈fdiff4 atol=1e-5
        @test finitediff≈fdiff5 atol=1e-5
        @test finitediff≈fdiff6 atol=1e-5
        @test fdiff4 ≈ fdiff5 ≈ fdiff6
    end
end