Skip to content

Commit 8223e3e

Browse files
Merge pull request #1176 from ChrisRackauckas/fix-formatting
Apply JuliaFormatter to fix formatting CI
2 parents 55012ce + b68b950 commit 8223e3e

File tree

8 files changed

+47
-38
lines changed

8 files changed

+47
-38
lines changed

ext/DiffEqBaseEnzymeExt.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ module DiffEqBaseEnzymeExt
77
import Enzyme: Const
88
using ChainRulesCore
99

10-
function Enzyme.EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.RevConfigWidth{1},
10+
function Enzyme.EnzymeRules.augmented_primal(
11+
config::Enzyme.EnzymeRules.RevConfigWidth{1},
1112
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob,
12-
sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
13+
sensealg::Union{
14+
Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
1315
u0, p, args...; kwargs...) where {RT}
1416
@inline function copy_or_reuse(val, idx)
1517
if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val)
@@ -36,7 +38,8 @@ module DiffEqBaseEnzymeExt
3638

3739
function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1},
3840
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob,
39-
sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
41+
sensealg::Union{
42+
Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
4043
u0, p, args...; kwargs...) where {RT}
4144
dres, clos = tape
4245
dres = dres::RT
@@ -53,7 +56,6 @@ module DiffEqBaseEnzymeExt
5356
Enzyme.make_zero!(dres.u)
5457
return ntuple(_ -> nothing, Val(length(args) + 4))
5558
end
56-
5759
end
5860

5961
end

ext/DiffEqBaseTrackerExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ Tracker.@grad function DiffEqBase.solve_up(prob,
103103
},
104104
u0, p, args...;
105105
kwargs...)
106-
sol, pb_f = DiffEqBase._solve_adjoint(
106+
sol,
107+
pb_f = DiffEqBase._solve_adjoint(
107108
prob, sensealg, Tracker.data(u0), Tracker.data(p),
108109
SciMLBase.TrackerOriginator(), args...; kwargs...)
109110

src/callbacks.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,9 @@ has_continuous_callback(cb::DiscreteCallback) = false
5050
has_continuous_callback(cb::ContinuousCallback) = true
5151
has_continuous_callback(cb::VectorContinuousCallback) = true
5252
has_continuous_callback(cb::CallbackSet) = !isempty(cb.continuous_callbacks)
53-
has_continuous_callback(cb::Nothing) = false
53+
has_continuous_callback(cb::Nothing) = false#======================================================##======================================================#
5454

55-
#======================================================#
5655
# Callback handling
57-
#======================================================#
5856

5957
function get_tmp(integrator::DEIntegrator, callback)
6058
_tmp = get_tmp_cache(integrator)
@@ -129,14 +127,16 @@ end
129127
AbstractContinuousCallback
130128
}) where {N}
131129
ex = quote
132-
tmin, upcrossing, event_occurred, event_idx = find_callback_time(integrator,
130+
tmin, upcrossing,
131+
event_occurred, event_idx = find_callback_time(integrator,
133132
callbacks[1], 1)
134133
identified_idx = 1
135134
end
136135
for i in 2:N
137136
ex = quote
138137
$ex
139-
tmin2, upcrossing2, event_occurred2, event_idx2 = find_callback_time(
138+
tmin2, upcrossing2,
139+
event_occurred2, event_idx2 = find_callback_time(
140140
integrator,
141141
callbacks[$i],
142142
$i)
@@ -398,7 +398,8 @@ function findall_events!(next_sign, affect!::F1, affect_neg!::F2, prev_sign) whe
398398
end
399399

400400
function find_callback_time(integrator, callback::ContinuousCallback, counter)
401-
event_occurred, interp_index, ts, prev_sign, prev_sign_index, event_idx = determine_event_occurance(
401+
event_occurred, interp_index, ts, prev_sign,
402+
prev_sign_index, event_idx = determine_event_occurance(
402403
integrator,
403404
callback,
404405
counter)
@@ -459,7 +460,8 @@ function find_callback_time(integrator, callback::ContinuousCallback, counter)
459460
end
460461

461462
function find_callback_time(integrator, callback::VectorContinuousCallback, counter)
462-
event_occurred, interp_index, ts, prev_sign, prev_sign_index, event_idx = determine_event_occurance(
463+
event_occurred, interp_index, ts, prev_sign,
464+
prev_sign_index, event_idx = determine_event_occurance(
463465
integrator,
464466
callback,
465467
counter)
@@ -636,7 +638,8 @@ end
636638
@inline function apply_discrete_callback!(integrator, discrete_modified::Bool,
637639
saved_in_cb::Bool, callback::DiscreteCallback,
638640
args...)
639-
bool, saved_in_cb2 = apply_discrete_callback!(integrator,
641+
bool,
642+
saved_in_cb2 = apply_discrete_callback!(integrator,
640643
apply_discrete_callback!(integrator,
641644
callback)...,
642645
args...)

src/solve.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,7 @@ function build_null_solution(prob::AbstractDEProblem, args...;
719719
save_everystep = true,
720720
save_on = true,
721721
save_start = save_everystep || isempty(saveat) ||
722-
saveat isa Number || prob.tspan[1] in saveat,
722+
saveat isa Number || prob.tspan[1] in saveat,
723723
save_end = true,
724724
kwargs...)
725725
ts = if saveat === ()
@@ -752,7 +752,7 @@ function build_null_solution(
752752
save_everystep = true,
753753
save_on = true,
754754
save_start = save_everystep || isempty(saveat) ||
755-
saveat isa Number || prob.tspan[1] in saveat,
755+
saveat isa Number || prob.tspan[1] in saveat,
756756
save_end = true,
757757
kwargs...)
758758
prob, success = hack_null_solution_init(prob)
@@ -1080,13 +1080,13 @@ function solve(prob::AbstractDEProblem, args...; sensealg = nothing,
10801080
p = p !== nothing ? p : prob.p
10811081

10821082
if wrap isa Val{true}
1083-
wrap_sol(solve_up(prob, sensealg, u0, p, args...;
1084-
originator = set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()),
1085-
kwargs...))
1083+
wrap_sol(solve_up(prob, sensealg, u0, p, args...;
1084+
originator = set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()),
1085+
kwargs...))
10861086
else
1087-
solve_up(prob, sensealg, u0, p, args...;
1088-
originator = set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()),
1089-
kwargs...)
1087+
solve_up(prob, sensealg, u0, p, args...;
1088+
originator = set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()),
1089+
kwargs...)
10901090
end
10911091
end
10921092

test/downstream/complex_number_ad.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ function assert_fun()
6060
p0 = rand(3)
6161
isapprox(loss(p0), loss_via_real(p0); rtol = 1e-4)
6262
end
63-
@assert all([assert_fun() for _ in 1:(2^6)])
63+
@assert all([assert_fun() for _ in 1:(2 ^ 6)])
6464

6565
# test ad with ForwardDiff
6666
function test_ad()
@@ -74,4 +74,4 @@ function test_ad()
7474
isapprox(grad_complex, grad_real; rtol = 1e-6) ? true : (@show rel_err; false)
7575
end
7676

77-
@time @test all([test_ad() for _ in 1:(2^6)])
77+
@time @test all([test_ad() for _ in 1:(2 ^ 6)])

test/downstream/ensemble_ad.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ eu0 = rand(N, 2)
6565
ep = rand(N, 4)
6666

6767
ensemble_prob = EnsembleProblem(prob,
68-
prob_func = (prob, i, repeat) -> remake(prob,
68+
prob_func = (
69+
prob, i, repeat) -> remake(prob,
6970
u0 = eu0[i, :],
7071
p = ep[i, :],
7172
saveat = 0.1))
@@ -75,7 +76,8 @@ cache = Ref{Any}()
7576

7677
function sum_of_e_solution(p)
7778
ensemble_prob = EnsembleProblem(prob,
78-
prob_func = (prob, i, repeat) -> remake(prob,
79+
prob_func = (
80+
prob, i, repeat) -> remake(prob,
7981
u0 = eu0[i, :],
8082
p = p[i, :],
8183
saveat = 0.1))

test/downstream/ensemble_analysis.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,15 @@ m3, m4, c = timepoint_meancov(sim, 0.5, 0.5)
4949
@test v c
5050
m3, m4, c = timepoint_meancor(sim, 0.5, 0.5)
5151
@test c one(c)
52-
m_series = timeseries_point_mean(sim, 0:(1 // 2^(3)):1)
53-
m_series = timeseries_point_median(sim, 0:(1 // 2^(3)):1)
54-
m_series = timeseries_point_quantile(sim, 0.5, 0:(1 // 2^(3)):1)
55-
m_series, v_series = timeseries_point_meanvar(sim, 0:(1 // 2^(3)):1)
56-
summ = EnsembleSummary(sim, 0:(1 // 2^(3)):1)
52+
m_series = timeseries_point_mean(sim, 0:(1 // 2 ^ (3)):1)
53+
m_series = timeseries_point_median(sim, 0:(1 // 2 ^ (3)):1)
54+
m_series = timeseries_point_quantile(sim, 0.5, 0:(1 // 2 ^ (3)):1)
55+
m_series, v_series = timeseries_point_meanvar(sim, 0:(1 // 2 ^ (3)):1)
56+
summ = EnsembleSummary(sim, 0:(1 // 2 ^ (3)):1)
5757
m5, v5 = m_series.u[5], v_series.u[5]
5858
@test m m5
5959
@test v v5
60-
m6, m7, v6 = timeseries_point_meancov(sim, 0:(1 // 2^(3)):1, 0:(1 // 2^(3)):1)[5, 5]
60+
m6, m7, v6 = timeseries_point_meancov(sim, 0:(1 // 2 ^ (3)):1, 0:(1 // 2 ^ (3)):1)[5, 5]
6161
@test m m6
6262
@test m m7
6363
@test v v6
@@ -107,15 +107,15 @@ m3, m4, c = timepoint_meancov(sim, 0.5, 0.5)
107107
@test v c
108108
m3, m4, c = timepoint_meancor(sim, 0.5, 0.5)
109109
@test c ones(size(c)...)
110-
m_series = timeseries_point_mean(sim, 0:(1 // 2^(3)):1)
111-
m_series = timeseries_point_median(sim, 0:(1 // 2^(3)):1)
112-
m_series = timeseries_point_quantile(sim, 0.5, 0:(1 // 2^(3)):1)
113-
m_series, v_series = timeseries_point_meanvar(sim, 0:(1 // 2^(3)):1)
114-
summ = EnsembleSummary(sim, 0:(1 // 2^(3)):1)
110+
m_series = timeseries_point_mean(sim, 0:(1 // 2 ^ (3)):1)
111+
m_series = timeseries_point_median(sim, 0:(1 // 2 ^ (3)):1)
112+
m_series = timeseries_point_quantile(sim, 0.5, 0:(1 // 2 ^ (3)):1)
113+
m_series, v_series = timeseries_point_meanvar(sim, 0:(1 // 2 ^ (3)):1)
114+
summ = EnsembleSummary(sim, 0:(1 // 2 ^ (3)):1)
115115
m5, v5 = m_series.u[5], v_series.u[5]
116116
@test m m5
117117
@test v v5
118-
m6, m7, v6 = timeseries_point_meancov(sim, 0:(1 // 2^(3)):1, 0:(1 // 2^(3)):1)[5, 5]
118+
m6, m7, v6 = timeseries_point_meancov(sim, 0:(1 // 2 ^ (3)):1, 0:(1 // 2 ^ (3)):1)[5, 5]
119119
@test m m6
120120
@test m m7
121121
@test v v6

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ end
7272
@time @safetestset "Unwrapping" include("downstream/unwrapping.jl")
7373
@time @safetestset "Callback BigFloats" include("downstream/bigfloat_events.jl")
7474
@time @safetestset "DE stats" include("downstream/stats_tests.jl")
75-
isempty(VERSION.prerelease) && @time @safetestset "Ensemble AD Tests" include("downstream/ensemble_ad.jl")
75+
isempty(VERSION.prerelease) &&
76+
@time @safetestset "Ensemble AD Tests" include("downstream/ensemble_ad.jl")
7677
@time @safetestset "Community Callback Tests" include("downstream/community_callback_tests.jl")
7778
@time @safetestset "AD via ode with complex numbers" include("downstream/complex_number_ad.jl")
7879
@time @testset "Distributed Ensemble Tests" include("downstream/distributed_ensemble.jl")

0 commit comments

Comments
 (0)