Skip to content

Commit 8406c00

Browse files
test: refactor parameter dependency tests
1 parent 48baee1 commit 8406c00

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

test/parameter_dependencies.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ using NonlinearSolve
2626
continuous_events = [cb1, cb2],
2727
discrete_events = [cb3]
2828
)
29-
@test isequal(only(parameters(sys)), p1)
30-
@test Set(full_parameters(sys)) == Set([p1, p2])
29+
@test !(p2 in Set(parameters(sys)))
30+
@test p2 in Set(full_parameters(sys))
3131
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.5), jac = true)
3232
@test prob.ps[p1] == 1.0
3333
@test prob.ps[p2] == 2.0
@@ -82,8 +82,8 @@ end
8282
parameter_dependencies = [p2 => 2p1]
8383
)
8484
sys = extend(sys2, sys1)
85-
@test isequal(only(parameters(sys)), p1)
86-
@test Set(full_parameters(sys)) == Set([p1, p2])
85+
@test !(p2 in Set(parameters(sys)))
86+
@test p2 in Set(full_parameters(sys))
8787
prob = ODEProblem(complete(sys))
8888
get_dep = getu(prob, 2p2)
8989
@test get_dep(prob) == 4
@@ -259,8 +259,8 @@ end
259259
@named sys = ODESystem(eqs, t)
260260
@named sdesys = SDESystem(sys, noiseeqs; parameter_dependencies ==> 2σ])
261261
sdesys = complete(sdesys)
262-
@test Set(parameters(sdesys)) == Set([σ, β])
263-
@test Set(full_parameters(sdesys)) == Set([σ, β, ρ])
262+
@test !in Set(parameters(sdesys)))
263+
@test ρ in Set(full_parameters(sdesys))
264264

265265
prob = SDEProblem(
266266
sdesys, [x => 1.0, y => 0.0, z => 0.0], (0.0, 100.0), [σ => 10.0, β => 2.33])
@@ -358,17 +358,18 @@ end
358358

359359
ps = prob.p
360360
buffer, repack, _ = canonicalize(Tunable(), ps)
361-
@test only(buffer) == 3.0
362-
buffer[1] = 4.0
361+
idx = parameter_index(sys, p1)
362+
@test buffer[idx.idx] == 3.0
363+
buffer[idx.idx] = 4.0
363364
ps = repack(buffer)
364365
@test getp(sys, p1)(ps) == 4.0
365366
@test getp(sys, p2)(ps) == 8.0
366367

367-
replace!(Tunable(), ps, [1.0])
368+
replace!(Tunable(), ps, ones(length(ps.tunable)))
368369
@test getp(sys, p1)(ps) == 1.0
369370
@test getp(sys, p2)(ps) == 2.0
370371

371-
ps2 = replace(Tunable(), ps, [2.0])
372+
ps2 = replace(Tunable(), ps, 2 .* ps.tunable)
372373
@test getp(sys, p1)(ps2) == 2.0
373374
@test getp(sys, p2)(ps2) == 4.0
374375
end

0 commit comments

Comments
 (0)