Skip to content

Commit 8ed31b1

Browse files
authored
CI badges and formatting (#134)
1 parent bfef194 commit 8ed31b1

19 files changed

+253
-210
lines changed

.JuliaFormatter.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
style = "sciml"

.github/dependabot.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
2+
version: 2
3+
updates:
4+
- package-ecosystem: "github-actions"
5+
directory: "/" # Location of package manifests
6+
schedule:
7+
interval: "weekly"

.github/workflows/codequality.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
name: Code Quality Check
2+
3+
on: [pull_request]
4+
5+
jobs:
6+
code-style:
7+
name: Format Suggestions
8+
runs-on: ubuntu-latest
9+
steps:
10+
- uses: julia-actions/julia-format@v3

README.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
1+
# Differentiable programming for Differential equations: a review
2+
3+
[![CI](https://github.com/ODINN-SciML/DiffEqSensitivity-Review/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/ODINN-SciML/DiffEqSensitivity-Review/actions/workflows/CI.yml)
14
![example workflow](https://github.com/ODINN-SciML/DiffEqSensitivity-Review/actions/workflows/latex.yml/badge.svg)
25
![example workflow](https://github.com/ODINN-SciML/DiffEqSensitivity-Review/actions/workflows/biblatex.yml/badge.svg)
3-
[![All Contributors](https://img.shields.io/github/all-contributors/ODINN-SciML/DiffEqSensitivity-Review?color=ee8449&style=flat-square)](#contributors)
46

5-
# Differentiable programming for Differential equations: a review
7+
[![All Contributors](https://img.shields.io/github/all-contributors/ODINN-SciML/DiffEqSensitivity-Review?color=ee8449&style=flat-square)](#contributors)
8+
[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle)
69

710
### ⚠️ New preprint available! 📖 ⚠️
811

912
The review paper is now available as a preprint on arXiv: https://arxiv.org/abs/2406.09699
1013

1114
If you want to cite this work, please use this BibTex citation:
12-
```
15+
16+
```bibtex
1317
@misc{sapienza2024differentiable,
1418
title={Differentiable Programming for Differential Equations: A Review},
1519
author={Facundo Sapienza and Jordi Bolibar and Frank Schäfer and Brian Groenke and Avik Pal and Victor Boussange and Patrick Heimbach and Giles Hooker and Fernando Pérez and Per-Olof Persson and Christopher Rackauckas},

code/DirectMethods/Comparison/direct-comparision.jl

Lines changed: 62 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ abstol = 1e-5
2020
function oscilatior!(du, u, p, t)
2121
ω = p[1]
2222
du[1] = u[2]
23-
du[2] = - ω^2 * u[1]
23+
du[2] = -ω^2 * u[1]
2424
nothing
2525
end
2626

@@ -36,14 +36,14 @@ function solution_derivative(t, u0, p)
3636
ω = p[1]
3737
A₀ = u0[2] / ω
3838
B₀ = u0[1]
39-
return A₀ * ( t * cos(ω * t) - sin(ω * t)/ω ) - B₀ * t * sin(ω * t)
39+
return A₀ * (t * cos(ω * t) - sin(ω * t) / ω) - B₀ * t * sin(ω * t)
4040
end
4141

4242
######### Simple example of how to run the dynamcics ###########
4343

4444
# Solve numerical problem
4545
prob = ODEProblem(oscilatior!, u0, tspan, p)
46-
sol = solve(prob, Tsit5(), reltol=reltol, abstol=abstol)
46+
sol = solve(prob, Tsit5(), reltol = reltol, abstol = abstol)
4747

4848
u_final = sol.u[end][1]
4949

@@ -65,17 +65,17 @@ function finitediff_solver(h, t, u0, p, reltol, abstol)
6565
tspan = (0.0, t)
6666
# Forward model with -h
6767
prob₋ = ODEProblem(oscilatior!, u0, tspan, p₋)
68-
sol₋ = solve(prob₋, Tsit5(), reltol=reltol, abstol=abstol)
68+
sol₋ = solve(prob₋, Tsit5(), reltol = reltol, abstol = abstol)
6969
# Forward model with +h
7070
prob₊ = ODEProblem(oscilatior!, u0, tspan, p₊)
71-
sol₊ = solve(prob₊, Tsit5(), reltol=reltol, abstol=abstol)
71+
sol₊ = solve(prob₊, Tsit5(), reltol = reltol, abstol = abstol)
7272

73-
return (sol₊.u[end][1] - sol₋.u[end][1]) /(2h)
73+
return (sol₊.u[end][1] - sol₋.u[end][1]) / (2h)
7474
end
7575

7676
######### Simulation with differerent stepsizes ###########
7777

78-
stepsizes = 2.0.^collect(round(log2(eps(Float64))):1:0)
78+
stepsizes = 2.0 .^ collect(round(log2(eps(Float64))):1:0)
7979
times = collect(t₀:1.0:t₁)
8080

8181
# True derivative computend analytially
@@ -85,71 +85,85 @@ derivative_true = solution_derivative(t₁, u0, p)
8585
# derivative_numerical = finitediff_numerical.(stepsizes, Ref(t₁), Ref(u0), Ref(p))
8686
derivative_numerical = finitediff_numerical.(stepsizes, Ref(t₁), Ref(u0), Ref(p))
8787
derivative_finitediff_exact = finitediff_numerical.(stepsizes, Ref(t₁), Ref(u0), Ref(p))
88-
error_finitediff_exact = abs.((derivative_numerical .- derivative_true)./derivative_true)
88+
error_finitediff_exact = abs.((derivative_numerical .- derivative_true) ./ derivative_true)
8989

9090
# Finite differences with solution from solver and low tolerance
91-
derivative_solver_low = finitediff_solver.(stepsizes, Ref(t₁), Ref(u0), Ref(p), Ref(1e-6), Ref(1e-6))
92-
error_finitediff_low = abs.((derivative_solver_low .- derivative_true)./derivative_true)
91+
derivative_solver_low = finitediff_solver.(
92+
stepsizes, Ref(t₁), Ref(u0), Ref(p), Ref(1e-6), Ref(1e-6))
93+
error_finitediff_low = abs.((derivative_solver_low .- derivative_true) ./ derivative_true)
9394

9495
# Finite differences with solution from solver and high tolerance
95-
derivative_solver_high = finitediff_solver.(stepsizes, Ref(t₁), Ref(u0), Ref(p), Ref(1e-12), Ref(1e-12))
96-
error_finitediff_high = abs.((derivative_solver_high .- derivative_true)./derivative_true)
96+
derivative_solver_high = finitediff_solver.(
97+
stepsizes, Ref(t₁), Ref(u0), Ref(p), Ref(1e-12), Ref(1e-12))
98+
error_finitediff_high = abs.((derivative_solver_high .- derivative_true) ./ derivative_true)
9799

98100
# Complex step differentiation with solution from solver and high tolerance
99101
u0_complex = ComplexF64.(u0)
100102

101-
derivative_complex_low = complexstep_differentiation.(Ref(x -> solve(ODEProblem(oscilatior!, u0_complex, tspan, [x]), Tsit5(), reltol=1e-6, abstol=1e-6).u[end][1]), Ref(p[1]), stepsizes)
103+
derivative_complex_low = complexstep_differentiation.(
104+
Ref(x -> solve(ODEProblem(oscilatior!, u0_complex, tspan, [x]), Tsit5(), reltol = 1e-6, abstol = 1e-6).u[end][1]),
105+
Ref(p[1]),
106+
stepsizes)
102107
error_complex_low = abs.((derivative_true .- derivative_complex_low) ./ derivative_true)
103-
derivative_complex_high = complexstep_differentiation.(Ref(x -> solve(ODEProblem(oscilatior!, u0_complex, tspan, [x]), Tsit5(), reltol=1e-12, abstol=1e-12).u[end][1]), Ref(p[1]), stepsizes)
108+
derivative_complex_high = complexstep_differentiation.(
109+
Ref(x -> solve(ODEProblem(oscilatior!, u0_complex, tspan, [x]), Tsit5(), reltol = 1e-12, abstol = 1e-12).u[end][1]),
110+
Ref(p[1]),
111+
stepsizes)
104112
error_complex_high = abs.((derivative_true .- derivative_complex_high) ./ derivative_true)
105113

106114
# Complex step Differentiation
107-
derivative_complex_exact = ComplexDiff.derivative.(ω -> solution(t₁, u0, [ω]), p[1], stepsizes)
108-
error_complex_exact = abs.((derivative_complex_exact .- derivative_true)./derivative_true)
115+
derivative_complex_exact = ComplexDiff.derivative.(
116+
ω -> solution(t₁, u0, [ω]), p[1], stepsizes)
117+
error_complex_exact = abs.((derivative_complex_exact .- derivative_true) ./ derivative_true)
109118

110119
# Forward AD applied to numerical solver
111-
derivative_AD_low = Zygote.gradient(p->solve(ODEProblem(oscilatior!, u0, tspan, p), Tsit5(), reltol=1e-6, abstol=1e-6).u[end][1], p)[1][1]
120+
derivative_AD_low = Zygote.gradient(
121+
p -> solve(ODEProblem(oscilatior!, u0, tspan, p), Tsit5(), reltol = 1e-6, abstol = 1e-6).u[end][1],
122+
p)[1][1]
112123
error_AD_low = abs((derivative_true - derivative_AD_low) / derivative_true)
113124

114-
derivative_AD_high = Zygote.gradient(p->solve(ODEProblem(oscilatior!, u0, tspan, p), Tsit5(), reltol=1e-12, abstol=1e-12).u[end][1], p)[1][1]
125+
derivative_AD_high = Zygote.gradient(
126+
p -> solve(ODEProblem(oscilatior!, u0, tspan, p), Tsit5(), reltol = 1e-12, abstol = 1e-12).u[end][1],
127+
p)[1][1]
115128
error_AD_high = abs((derivative_true - derivative_AD_high) / derivative_true)
116129

117-
118130
######### Figure ###########
119131

120-
color_finitediff = RGBf(192/255, 57/255, 43/255)
121-
color_finitediff_low = RGBf(230/255, 126/255, 34/255)
122-
color_complex = RGBf(41/255, 128/255, 185/255)
123-
color_complex_low = RGBf(52/255, 152/255, 219/255)
124-
color_AD = RGBf(142/255, 68/255, 173/255)
125-
color_AD_low = RGBf(155/255, 89/255, 182/255)
132+
color_finitediff = RGBf(192 / 255, 57 / 255, 43 / 255)
133+
color_finitediff_low = RGBf(230 / 255, 126 / 255, 34 / 255)
134+
color_complex = RGBf(41 / 255, 128 / 255, 185 / 255)
135+
color_complex_low = RGBf(52 / 255, 152 / 255, 219 / 255)
136+
color_AD = RGBf(142 / 255, 68 / 255, 173 / 255)
137+
color_AD_low = RGBf(155 / 255, 89 / 255, 182 / 255)
126138

127-
fig = Figure(resolution=(1000, 400))
128-
ax = Axis(fig[1, 1], xlabel = L"Stepsize ($\varepsilon$)", ylabel = L"\text{Absolute relative error}",
129-
xscale = log10, yscale=log10, xlabelsize=24, ylabelsize=24, xticklabelsize=18, yticklabelsize=18)
139+
fig = Figure(resolution = (1000, 400))
140+
ax = Axis(fig[1, 1], xlabel = L"Stepsize ($\varepsilon$)",
141+
ylabel = L"\text{Absolute relative error}", xscale = log10, yscale = log10,
142+
xlabelsize = 24, ylabelsize = 24, xticklabelsize = 18, yticklabelsize = 18)
130143

131144
# Plot derivatived of true solution (no numerical solver)
132-
lines!(ax, stepsizes, error_finitediff_exact, label=L"\text{Finite differences (exact solution)}",
133-
color=color_finitediff, linewidth=2, linestyle = :dash)
145+
lines!(ax, stepsizes, error_finitediff_exact,
146+
label = L"\text{Finite differences (exact solution)}",
147+
color = color_finitediff, linewidth = 2, linestyle = :dash)
134148
lines!(ax, stepsizes, error_complex_exact,
135-
label=L"\text{Complex step differentiation (exact solution)}",
136-
color=color_complex, linewidth=2, linestyle = :dash)
137-
149+
label = L"\text{Complex step differentiation (exact solution)}",
150+
color = color_complex, linewidth = 2, linestyle = :dash)
151+
138152
# Plot derivatives computed on top of numerical solver with finite differences
139-
scatter!(ax, stepsizes, error_finitediff_low,
140-
label=L"Finite differences (tol=$10^{-6}$)", color=color_finitediff_low,
141-
marker ='', markersize=20)
142-
scatter!(ax, stepsizes, error_finitediff_high,
143-
label=L"Finite differences (tol=$10^{-12}$)", color=color_finitediff,
144-
marker ='', markersize=30)
153+
scatter!(
154+
ax, stepsizes, error_finitediff_low, label = L"Finite differences (tol=$10^{-6}$)",
155+
color = color_finitediff_low, marker = '', markersize = 20)
156+
scatter!(
157+
ax, stepsizes, error_finitediff_high, label = L"Finite differences (tol=$10^{-12}$)",
158+
color = color_finitediff, marker = '', markersize = 30)
145159

146160
# Plot derivatives computed on top of numerical solver with complex step method
147161
scatter!(ax, stepsizes, error_complex_low,
148-
label=L"Complex step differentiation (tol=$10^{-6}$)",
149-
color=color_complex_low, marker ='', markersize=20)
162+
label = L"Complex step differentiation (tol=$10^{-6}$)",
163+
color = color_complex_low, marker = '', markersize = 20)
150164
scatter!(ax, stepsizes, error_complex_high,
151-
label=L"Complex step differentiation (tol=$10^{-12}$)", color=color_complex,
152-
marker ='', markersize=30)
165+
label = L"Complex step differentiation (tol=$10^{-12}$)",
166+
color = color_complex, marker = '', markersize = 30)
153167

154168
# AD
155169
# hlines!(ax, [error_AD_low, error_AD_high], color=color_AD, linewidth=1.5)
@@ -158,20 +172,19 @@ scatter!(ax, stepsizes, error_complex_high,
158172
# plot!(ax, [stepsizes[begin], stepsizes[end]],[error_AD_high, error_AD_high],
159173
# color=color_AD, label=L"Forward AD (tol=$10^{-12}$)", marker='•', markersize=25)
160174

161-
lines!(ax, stepsizes, repeat([error_AD_low], length(stepsizes)),
162-
color=color_AD_low, label=L"Forward AD (tol=$10^{-6}$)", linewidth=2)
163-
lines!(ax, stepsizes, repeat([error_AD_high], length(stepsizes)),
164-
color=color_AD, label=L"Forward AD (tol=$10^{-12}$)", linewidth=3)
175+
lines!(ax, stepsizes, repeat([error_AD_low], length(stepsizes)),
176+
color = color_AD_low, label = L"Forward AD (tol=$10^{-6}$)", linewidth = 2)
177+
lines!(ax, stepsizes, repeat([error_AD_high], length(stepsizes)),
178+
color = color_AD, label = L"Forward AD (tol=$10^{-12}$)", linewidth = 3)
165179

166180
# Add legend
167181
fig[1, 2] = Legend(fig, ax)
168182

169183
!ispath("Figures") && mkpath("Figures")
170184
save("Figures/DirectMethods_comparison.pdf", fig)
171185

172-
173186
######### Benchmark ###########
174187

175188
# It looks like complex step has better performance... both in speed and momory allocation.
176189
# @benchmark derivative_complex_low = complexstep_differentiation.(Ref(x -> solve(ODEProblem(oscilatior!, u0_complex, tspan, [x]), Tsit5(), reltol=1e-6, abstol=1e-6).u[end][1]), Ref(p[1]), [1e-5])
177-
# @benchmark derivative_AD_low = Zygote.gradient(p->solve(ODEProblem(oscilatior!, u0, tspan, p), Tsit5(), reltol=1e-6, abstol=1e-6).u[end][1], p)[1][1]
190+
# @benchmark derivative_AD_low = Zygote.gradient(p->solve(ODEProblem(oscilatior!, u0, tspan, p), Tsit5(), reltol=1e-6, abstol=1e-6).u[end][1], p)[1][1]

code/DirectMethods/ComplexStep/complex_solver.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using OrdinaryDiffEq
33
function dyn!(du::Array{Complex{Float64}}, u::Array{Complex{Float64}}, p, t)
44
ω = p[1]
55
du[1] = u[2]
6-
du[2] = - ω^2 * u[1]
6+
du[2] = -ω^2 * u[1]
77
end
88

99
tspan = [0.0, 10.0]
@@ -15,4 +15,5 @@ function complexstep_differentiation(f::Function, p::Float64, ε::Float64)
1515
return imag(f(p_complex)) / ε
1616
end
1717

18-
complexstep_differentiation(x -> solve(ODEProblem(dyn!, u0, tspan, [x]), Tsit5()).u[end][1], 20., 1e-3)
18+
complexstep_differentiation(
19+
x -> solve(ODEProblem(dyn!, u0, tspan, [x]), Tsit5()).u[end][1], 20.0, 1e-3)

code/DirectMethods/DualNumbers/dualnumber_definition.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,38 +11,41 @@ end
1111
# Outer constructor
1212
function DualNumber(value::F) where {F <: AbstractFloat}
1313
DualNumber(value, F(0.0))
14-
end
14+
end
1515

1616
# Chain rules for binary opperators
1717

1818
# Binary sum
19-
Base.:(+)(a::DualNumber, b::DualNumber) = DualNumber(value = a.value + b.value,
20-
derivative = a.derivative + b.derivative)
19+
function Base.:(+)(a::DualNumber, b::DualNumber)
20+
DualNumber(value = a.value + b.value, derivative = a.derivative + b.derivative)
21+
end
2122

2223
# Binary product
23-
Base.:(*)(a::DualNumber, b::DualNumber) = DualNumber(value = a.value * b.value,
24-
derivative = a.value*b.derivative + a.derivative*b.value)
24+
function Base.:(*)(a::DualNumber, b::DualNumber)
25+
DualNumber(value = a.value * b.value,
26+
derivative = a.value * b.derivative + a.derivative * b.value)
27+
end
2528

2629
# Power
27-
Base.:(^)(a::DualNumber, b::AbstractFloat) = DualNumber(value = a.value ^ b,
28-
derivative = b * a.value^(b-1) * a.derivative)
29-
30+
function Base.:(^)(a::DualNumber, b::AbstractFloat)
31+
DualNumber(value = a.value^b, derivative = b * a.value^(b - 1) * a.derivative)
32+
end
3033

3134
# Special functions
3235

3336
function Base.:(sin)(a::DualNumber)
3437
value = sin(a.value)
3538
derivative = a.derivative * cos(a.value)
36-
return DualNumber(value=value, derivative=derivative)
39+
return DualNumber(value = value, derivative = derivative)
3740
end
3841

3942
# Now we define a series of variables. We are interested in computing the derivative with respect to the variable "a":
4043

41-
a = DualNumber(value=1.0, derivative=1.0)
44+
a = DualNumber(value = 1.0, derivative = 1.0)
4245

43-
b = DualNumber(value=2.0, derivative=0.0)
44-
c = DualNumber(value=3.0, derivative=0.0)
46+
b = DualNumber(value = 2.0, derivative = 0.0)
47+
c = DualNumber(value = 3.0, derivative = 0.0)
4548

4649
# Now, we can evaluate a new DualNumber
4750
result = a * b * c
48-
# println("The derivative of a*b*c with respect to a is: ", result.derivative)
51+
# println("The derivative of a*b*c with respect to a is: ", result.derivative)

code/DirectMethods/DualNumbers/dualnumber_tolerances.jl

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,20 @@ This generates solutions u(t) = (t-θ)^5/5 that can be solved exactly with a 5th
1515
"""
1616
function dyn!(du, u, p, t)
1717
θ = p[1]
18-
du .= (t .- θ).^4.0
18+
du .= (t .- θ) .^ 4.0
1919
end
2020

2121
p = [1.0]
2222

2323
prob = ODEProblem(dyn!, u0, tspan, p)
24-
sol = solve(prob, Tsit5(), reltol=reltol, abstol=abstol)
24+
sol = solve(prob, Tsit5(), reltol = reltol, abstol = abstol)
2525

2626
# We can see that the time steps increase with non-stop
2727
# @show diff(sol.t)
2828

2929
function loss(p, sensealg)
3030
prob = ODEProblem(dyn!, u0, tspan, p)
31-
sol = solve(prob, Tsit5(), sensealg=sensealg, reltol=reltol, abstol=abstol)
31+
sol = solve(prob, Tsit5(), sensealg = sensealg, reltol = reltol, abstol = abstol)
3232
@show "Number of time steps: ", length(sol.t)
3333
sol.u[end][1]
3434
end
@@ -52,21 +52,23 @@ condition(u, t, integrator) = true
5252
function printstepsize!(integrator)
5353
if length(integrator.sol.t) > 1
5454
println("Stepsize at step ", length(integrator.sol.t), ": ",
55-
integrator.sol.t[end] - integrator.sol.t[end-1])
55+
integrator.sol.t[end] - integrator.sol.t[end - 1])
5656
end
5757
end
5858

5959
cb = DiscreteCallback(condition, printstepsize!)
6060

6161
# g1 = Zygote.gradient(p -> loss(p, ForwardDiffSensitivity()), internalnorm = (u,t) -> sum(abs2,u/length(u)), p)
62-
g1 = Zygote.gradient(p -> solve(ODEProblem(dyn!, u0, tspan, p),
63-
Tsit5(),
64-
sensealg = ForwardDiffSensitivity(),
65-
saveat = 0.1,
66-
internalnorm = (u,t) -> sum(abs2, u/length(u)),
67-
callback = cb,
68-
reltol=1e-6,
69-
abstol=1e-6).u[end][1], p)
62+
g1 = Zygote.gradient(
63+
p -> solve(ODEProblem(dyn!, u0, tspan, p),
64+
Tsit5(),
65+
sensealg = ForwardDiffSensitivity(),
66+
saveat = 0.1,
67+
internalnorm = (u, t) -> sum(abs2, u / length(u)),
68+
callback = cb,
69+
reltol = 1e-6,
70+
abstol = 1e-6).u[end][1],
71+
p)
7072
@show g1
7173

7274
# Forward Sensitivity
@@ -82,15 +84,15 @@ g1 = Zygote.gradient(p -> solve(ODEProblem(dyn!, u0, tspan, p),
8284

8385
# Corrected AD
8486
# g3 = ForwardDiff.gradient(p -> loss(p, nothing), p)
85-
g3 = Zygote.gradient(p -> solve(ODEProblem(dyn!, u0, tspan, p),
86-
Tsit5(),
87-
sensealg = ForwardDiffSensitivity(),
88-
# saveat = 0.1,
89-
# callback = cb,
90-
reltol=1e-6,
91-
abstol=1e-6).u[end][1], p)
87+
g3 = Zygote.gradient(
88+
p -> solve(ODEProblem(dyn!, u0, tspan, p),
89+
Tsit5(),
90+
sensealg = ForwardDiffSensitivity(), # saveat = 0.1, # callback = cb,
91+
reltol = 1e-6,
92+
abstol = 1e-6).u[end][1],
93+
p)
9294
@show g3
9395

9496
@show grad_true(p)
9597

96-
# Define customized RK(4) solver with given timesteps to show the divergence of forward sensitivities
98+
# Define customized RK(4) solver with given timesteps to show the divergence of forward sensitivities

0 commit comments

Comments
 (0)