Skip to content

Commit e643273

Browse files
Merge pull request #1175 from SciML/ChrisRackauckas-patch-3
Check new compats
2 parents 6810581 + 176bbc7 commit e643273

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,14 @@ SciMLSensitivityMooncakeExt = "Mooncake"
5252
ADTypes = "1.9"
5353
Accessors = "0.1.36"
5454
Adapt = "1.0, 2.0, 3.0, 4"
55-
AlgebraicMultigrid = "0.6.0"
55+
AlgebraicMultigrid = "1"
5656
Aqua = "0.8.4"
5757
ArrayInterface = "7"
5858
Calculus = "0.5.1"
5959
ChainRulesCore = "0.10.7, 1"
6060
ComponentArrays = "0.15.5"
6161
DelayDiffEq = "5.43.2"
62-
DiffEqBase = "6.151.1"
62+
DiffEqBase = "6.166.1"
6363
DiffEqCallbacks = "4"
6464
DiffEqNoiseProcess = "5.19"
6565
Distributed = "1"
@@ -92,7 +92,7 @@ RecursiveArrayTools = "3.27.2"
9292
Reexport = "1.0"
9393
ReverseDiff = "1.15.1"
9494
SafeTestsets = "0.1.0"
95-
SciMLBase = "2.51.4"
95+
SciMLBase = "2.79"
9696
SciMLJacobianOperators = "0.1"
9797
SciMLOperators = "0.3"
9898
SciMLStructures = "1.3"

test/sde_neural.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ Random.seed!(238248735)
9999
sensealg = ReverseDiffAdjoint()))
100100
tmp_mean = mean(tmp_sol, dims = 3)[:, :]
101101
tmp_var = var(tmp_sol, dims = 3)[:, :]
102-
sum(abs2, truemean - tmp_mean) + 0.1 * sum(abs2, truevar - tmp_var), tmp_mean
102+
sum(abs2, truemean - tmp_mean) + 0.1 * sum(abs2, truevar - tmp_var)
103103
end
104104

105105
function loss_op(θ)
@@ -112,11 +112,11 @@ Random.seed!(238248735)
112112
sensealg = ReverseDiffAdjoint()))
113113
tmp_mean = mean(tmp_sol, dims = 3)[:, :]
114114
tmp_var = var(tmp_sol, dims = 3)[:, :]
115-
sum(abs2, truemean - tmp_mean) + 0.1 * sum(abs2, truevar - tmp_var), tmp_mean
115+
sum(abs2, truemean - tmp_mean) + 0.1 * sum(abs2, truevar - tmp_var)
116116
end
117117

118118
losses = []
119-
function callback(θ, l, pred)
119+
function callback(state, l)
120120
begin
121121
push!(losses, l)
122122
if length(losses) % 50 == 0
@@ -189,12 +189,12 @@ end
189189
sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP()),
190190
saveat = ts, trajectories = 10, abstol = 1e-1, reltol = 1e-1)
191191
A = convert(Array, _sol)
192-
sum(abs2, A .- 1), mean(A)
192+
sum(abs2, A .- 1)
193193
end
194194

195195
# Actually training/fitting the model
196196
losses = []
197-
function callback(θ, l, pred)
197+
function callback(state, l)
198198
begin
199199
push!(losses, l)
200200
if length(losses) % 1 == 0

0 commit comments

Comments
 (0)