Skip to content

Commit 4ad8e5d

Browse files
Merge pull request #127 from avik-pal/ap/fast_precompile
Try to improve Compilation Speeds
2 parents e9c8a4c + dea3131 commit 4ad8e5d

19 files changed

+679
-407
lines changed

.github/workflows/CI.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ on:
66
push:
77
branches:
88
- master
9+
concurrency:
10+
# Skip intermediate builds: always.
11+
# Cancel intermediate builds: only if it is a pull request build.
12+
group: ${{ github.workflow }}-${{ github.ref }}
13+
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
914
jobs:
1015
test:
1116
runs-on: ubuntu-latest

Project.toml

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1313
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1414
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
1515
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
16+
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1617
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
1718
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1819
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
@@ -24,31 +25,37 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2425

2526
[weakdeps]
2627
ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
28+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
2729

2830
[extensions]
2931
BoundaryValueDiffEqODEInterfaceExt = "ODEInterface"
32+
BoundaryValueDiffEqOrdinaryDiffEqExt = "OrdinaryDiffEq"
3033

3134
[compat]
3235
ADTypes = "0.2"
3336
Adapt = "3"
37+
Aqua = "0.7"
3438
ArrayInterface = "7"
3539
BandedMatrices = "1"
3640
ConcreteStructs = "0.2"
37-
DiffEqBase = "6.94.2"
41+
DiffEqBase = "6.135"
3842
ForwardDiff = "0.10"
39-
NonlinearSolve = "2"
43+
NonlinearSolve = "2.5"
4044
ODEInterface = "0.5"
45+
OrdinaryDiffEq = "6"
4146
PreallocationTools = "0.4"
47+
PrecompileTools = "1"
4248
RecursiveArrayTools = "2.38.10"
4349
Reexport = "0.2, 1.0"
44-
SciMLBase = "2.2"
50+
SciMLBase = "2.5"
4551
Setfield = "1"
46-
SparseDiffTools = "2.6"
52+
SparseDiffTools = "2.9"
4753
TruncatedStacktraces = "1"
4854
UnPack = "1"
4955
julia = "1.9"
5056

5157
[extras]
58+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
5259
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
5360
ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
5461
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
@@ -58,4 +65,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
5865
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5966

6067
[targets]
61-
test = ["StaticArrays", "Random", "DiffEqDevTools", "OrdinaryDiffEq", "Test", "SafeTestsets", "ODEInterface"]
68+
test = ["StaticArrays", "Random", "DiffEqDevTools", "OrdinaryDiffEq", "Test", "SafeTestsets", "ODEInterface", "Aqua"]

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
# BoundaryValueDiffEq
22

3+
[![Join the chat at https://julialang.zulipchat.com #sciml-bridged](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/279055-sciml-bridged)
4+
[![Global Docs](https://img.shields.io/badge/docs-SciML-blue.svg)](https://docs.sciml.ai/BoundaryValueDiffEq/stable/)
5+
36
[![Build Status](https://github.com/SciML/BoundaryValueDiffEq.jl/workflows/CI/badge.svg)](https://github.com/SciML/BoundaryValueDiffEq.jl/actions?query=workflow%3ACI)
47
[![codecov](https://codecov.io/gh/SciML/BoundaryValueDiffEq.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/SciML/BoundaryValueDiffEq.jl)
8+
[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/BoundaryValueDiffEq)](https://pkgs.genieframework.com?packages=BoundaryValueDiffEq)
9+
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
510

611
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
712
[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle)

ext/BoundaryValueDiffEqODEInterfaceExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ end
6666
# BVPSOL
6767
#-------
6868
function __solve(prob::BVProblem, alg::BVPSOL; maxiters = 1000, reltol = 1e-3,
69-
dt = 0.0, verbose = true, kwargs...)
69+
dt = 0.0, verbose = true, kwargs...)
7070
_test_bvpm2_bvpsol_problem_criteria(prob, prob.problem_type, :BVPSOL)
7171
@assert isa(prob.p, SciMLBase.NullParameters) "BVPSOL only supports NullParameters!"
7272
@assert isa(prob.u0, AbstractVector{<:AbstractArray}) "BVPSOL requires a vector of initial guesses!"
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
module BoundaryValueDiffEqOrdinaryDiffEqExt
2+
3+
# This extension doesn't add any new feature atm but is used to precompile some common
4+
# shooting workflows
5+
6+
import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidations
7+
8+
@recompile_invalidations begin
9+
using BoundaryValueDiffEq, OrdinaryDiffEq
10+
end
11+
12+
@setup_workload begin
13+
function f1!(du, u, p, t)
14+
du[1] = u[2]
15+
du[2] = 0
16+
end
17+
f1(u, p, t) = [u[2], 0]
18+
19+
function bc1!(residual, u, p, t)
20+
residual[1] = u(0.0)[1] - 5
21+
residual[2] = u(5.0)[1]
22+
end
23+
bc1(u, p, t) = [u(0.0)[1] - 5, u(5.0)[1]]
24+
25+
bc1_a!(residual, ua, p) = (residual[1] = ua[1] - 5)
26+
bc1_b!(residual, ub, p) = (residual[1] = ub[1])
27+
28+
bc1_a(ua, p) = [ua[1] - 5]
29+
bc1_b(ub, p) = [ub[1]]
30+
31+
tspan = (0.0, 5.0)
32+
u0 = [5.0, -3.5]
33+
bcresid_prototype = (Array{Float64}(undef, 1), Array{Float64}(undef, 1))
34+
35+
probs = [
36+
BVProblem(f1!, bc1!, u0, tspan),
37+
BVProblem(f1, bc1, u0, tspan),
38+
TwoPointBVProblem(f1!, (bc1_a!, bc1_b!), u0, tspan; bcresid_prototype),
39+
TwoPointBVProblem(f1, (bc1_a, bc1_b), u0, tspan; bcresid_prototype),
40+
]
41+
42+
algs = [
43+
Shooting(Tsit5();
44+
nlsolve = NewtonRaphson(; autodiff = AutoForwardDiff(chunksize = 2))),
45+
MultipleShooting(10,
46+
Tsit5();
47+
nlsolve = NewtonRaphson(; autodiff = AutoForwardDiff(chunksize = 2)),
48+
jac_alg = BVPJacobianAlgorithm(;
49+
bc_diffmode = AutoForwardDiff(; chunksize = 2),
50+
nonbc_diffmode = AutoSparseForwardDiff(; chunksize = 2))),
51+
]
52+
53+
@compile_workload begin
54+
for prob in probs, alg in algs
55+
solve(prob, alg)
56+
end
57+
end
58+
end
59+
60+
end

src/BoundaryValueDiffEq.jl

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,26 @@
11
module BoundaryValueDiffEq
22

3-
using Adapt, BandedMatrices, ForwardDiff, LinearAlgebra, PreallocationTools,
4-
RecursiveArrayTools, Reexport, Setfield, SparseArrays
5-
@reexport using ADTypes, DiffEqBase, NonlinearSolve, SparseDiffTools, SciMLBase
3+
import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidations
4+
5+
@recompile_invalidations begin
6+
using ADTypes, Adapt, BandedMatrices, DiffEqBase, ForwardDiff, LinearAlgebra,
7+
NonlinearSolve, PreallocationTools, RecursiveArrayTools, Reexport, SciMLBase,
8+
Setfield, SparseArrays, SparseDiffTools
9+
10+
import ADTypes: AbstractADType
11+
import ArrayInterface: matrix_colors,
12+
parameterless_type, undefmatrix, fast_scalar_indexing
13+
import ConcreteStructs: @concrete
14+
import DiffEqBase: solve
15+
import ForwardDiff: pickchunksize
16+
import RecursiveArrayTools: ArrayPartition, DiffEqArray
17+
import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve, _unwrap_val
18+
import SparseDiffTools: AbstractSparseADType
19+
import TruncatedStacktraces: @truncate_stacktrace
20+
import UnPack: @unpack
21+
end
622

7-
import ADTypes: AbstractADType
8-
import ArrayInterface: matrix_colors, parameterless_type, undefmatrix, fast_scalar_indexing
9-
import ConcreteStructs: @concrete
10-
import DiffEqBase: solve
11-
import ForwardDiff: pickchunksize
12-
import RecursiveArrayTools: ArrayPartition, DiffEqArray
13-
import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve, _unwrap_val
14-
import SparseDiffTools: AbstractSparseADType
15-
import TruncatedStacktraces: @truncate_stacktrace
16-
import UnPack: @unpack
23+
@reexport using ADTypes, DiffEqBase, NonlinearSolve, SparseDiffTools, SciMLBase
1724

1825
include("types.jl")
1926
include("utils.jl")
@@ -37,6 +44,48 @@ function __solve(prob::BVProblem, alg::BoundaryValueDiffEqAlgorithm, args...; kw
3744
return solve!(cache)
3845
end
3946

47+
@setup_workload begin
48+
function f1!(du, u, p, t)
49+
du[1] = u[2]
50+
du[2] = 0
51+
end
52+
f1(u, p, t) = [u[2], 0]
53+
54+
function bc1!(residual, u, p, t)
55+
residual[1] = u[1][1] - 5
56+
residual[2] = u[end][1]
57+
end
58+
bc1(u, p, t) = [u[1][1] - 5, u[end][1]]
59+
60+
bc1_a!(residual, ua, p) = (residual[1] = ua[1] - 5)
61+
bc1_b!(residual, ub, p) = (residual[1] = ub[1])
62+
63+
bc1_a(ua, p) = [ua[1] - 5]
64+
bc1_b(ub, p) = [ub[1]]
65+
66+
tspan = (0.0, 5.0)
67+
u0 = [5.0, -3.5]
68+
bcresid_prototype = (Array{Float64}(undef, 1), Array{Float64}(undef, 1))
69+
70+
probs = [
71+
BVProblem(f1!, bc1!, u0, tspan),
72+
BVProblem(f1, bc1, u0, tspan),
73+
TwoPointBVProblem(f1!, (bc1_a!, bc1_b!), u0, tspan; bcresid_prototype),
74+
TwoPointBVProblem(f1, (bc1_a, bc1_b), u0, tspan; bcresid_prototype),
75+
]
76+
77+
jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2))
78+
79+
@compile_workload begin
80+
for prob in probs,
81+
alg in (MIRK2(; jac_alg), MIRK3(; jac_alg), MIRK4(; jac_alg),
82+
MIRK5(; jac_alg), MIRK6(; jac_alg))
83+
84+
solve(prob, alg; dt = 0.2)
85+
end
86+
end
87+
end
88+
4089
export Shooting, MultipleShooting
4190
export MIRK2, MIRK3, MIRK4, MIRK5, MIRK6
4291
export MIRKJacobianComputationAlgorithm, BVPJacobianAlgorithm

src/adaptivity.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ end
8484
Generate a new mesh based on the `ŝ`.
8585
"""
8686
function redistribute!(cache::MIRKCache{iip, T}, Nsub_star, ŝ, mesh,
87-
mesh_dt) where {iip, T}
87+
mesh_dt) where {iip, T}
8888
N = length(mesh)
8989
ζ = sum(ŝ .* mesh_dt) / Nsub_star
9090
k, i = 1, 0
@@ -229,7 +229,7 @@ function sum_stages!(cache::MIRKCache, w, w′, i::Int, dt = cache.mesh_dt[i])
229229
sum_stages!(cache.fᵢ_cache.du, cache.fᵢ₂_cache, cache, w, w′, i, dt)
230230
end
231231

232-
function sum_stages!(z, cache::MIRKCache, w, i::Int, dt = cache.mesh_dt[i])
232+
function sum_stages!(z::AbstractArray, cache::MIRKCache, w, i::Int, dt = cache.mesh_dt[i])
233233
@unpack M, stage, mesh, k_discrete, k_interp, mesh_dt = cache
234234
@unpack s_star = cache.ITU
235235

src/algorithms.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ function update_nshoots(alg::MultipleShooting, nshoots::Int)
4141
end
4242

4343
function MultipleShooting(nshoots::Int, ode_alg; nlsolve = NewtonRaphson(),
44-
grid_coarsening = true, jac_alg = BVPJacobianAlgorithm())
44+
grid_coarsening = true, jac_alg = BVPJacobianAlgorithm())
4545
@assert grid_coarsening isa Bool || grid_coarsening isa Function ||
4646
grid_coarsening isa AbstractVector{<:Integer} ||
4747
grid_coarsening isa NTuple{N, <:Integer} where {N}

src/collocation.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ function Φ!(residual, cache::MIRKCache, y, u, p = cache.p)
44
end
55

66
@views function Φ!(residual, fᵢ_cache, k_discrete, f!, TU::MIRKTableau, y, u, p,
7-
mesh, mesh_dt, stage::Int)
7+
mesh, mesh_dt, stage::Int)
88
@unpack c, v, x, b = TU
99

1010
tmp = get_tmp(fᵢ_cache, u)
@@ -35,7 +35,7 @@ function Φ(cache::MIRKCache, y, u, p = cache.p)
3535
end
3636

3737
@views function Φ(fᵢ_cache, k_discrete, f, TU::MIRKTableau, y, u, p, mesh, mesh_dt,
38-
stage::Int)
38+
stage::Int)
3939
@unpack c, v, x, b = TU
4040
residuals = [similar(yᵢ) for yᵢ in y[1:(end - 1)]]
4141
tmp = get_tmp(fᵢ_cache, u)

src/interpolation.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ end
1919
# FIXME: Fix the interpolation outside the tspan
2020

2121
@inline function interpolation(tvals, id::I, idxs, deriv::D, p,
22-
continuity::Symbol = :left) where {I, D}
22+
continuity::Symbol = :left) where {I, D}
2323
@unpack t, u, cache = id
2424
tdir = sign(t[end] - t[1])
2525
idx = sortperm(tvals, rev = tdir < 0)
@@ -41,7 +41,7 @@ end
4141
end
4242

4343
@inline function interpolation!(vals, tvals, id::I, idxs, deriv::D, p,
44-
continuity::Symbol = :left) where {I, D}
44+
continuity::Symbol = :left) where {I, D}
4545
@unpack t, cache = id
4646
tdir = sign(t[end] - t[1])
4747
idx = sortperm(tvals, rev = tdir < 0)
@@ -54,7 +54,7 @@ end
5454
end
5555

5656
@inline function interpolation(tval::Number, id::I, idxs, deriv::D, p,
57-
continuity::Symbol = :left) where {I, D}
57+
continuity::Symbol = :left) where {I, D}
5858
z = similar(id.cache.fᵢ₂_cache)
5959
interp_eval!(z, id.cache, tval, id.cache.mesh, id.cache.mesh_dt)
6060
return z

0 commit comments

Comments
 (0)