Skip to content

Commit 07d1e73

Browse files
refactor: format
1 parent 01da0dc commit 07d1e73

File tree

3 files changed

+40
-24
lines changed

3 files changed

+40
-24
lines changed

src/parameter_indexing.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,11 @@ function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
5353
end
5454
end
5555

56-
for (t1, t2) in [(ArraySymbolic, Any), (ScalarSymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})]
56+
for (t1, t2) in [
57+
(ArraySymbolic, Any),
58+
(ScalarSymbolic, Any),
59+
(NotSymbolic, Union{<:Tuple, <:AbstractArray}),
60+
]
5761
@eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2)
5862
getters = getp.((sys,), p)
5963

@@ -99,7 +103,11 @@ function _setp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
99103
end
100104
end
101105

102-
for (t1, t2) in [(ArraySymbolic, Any), (ScalarSymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})]
106+
for (t1, t2) in [
107+
(ArraySymbolic, Any),
108+
(ScalarSymbolic, Any),
109+
(NotSymbolic, Union{<:Tuple, <:AbstractArray}),
110+
]
103111
@eval function _setp(sys, ::NotSymbolic, ::$t1, p::$t2)
104112
setters = setp.((sys,), p)
105113
return function setter!(sol, val)

src/state_indexing.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,9 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
154154
current_time(prob))
155155
end
156156
function _getter2(::Timeseries, prob, i)
157-
return fn(state_values(prob, i), parameter_values(prob), current_time(prob, i))
157+
return fn(state_values(prob, i),
158+
parameter_values(prob),
159+
current_time(prob, i))
158160
end
159161
function _getter2(::NotTimeseries, prob)
160162
return fn(state_values(prob), parameter_values(prob), current_time(prob))
@@ -181,7 +183,11 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
181183
error("Invalid symbol $sym for `getu`")
182184
end
183185

184-
for (t1, t2) in [(ScalarSymbolic, Any), (ArraySymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})]
186+
for (t1, t2) in [
187+
(ScalarSymbolic, Any),
188+
(ArraySymbolic, Any),
189+
(NotSymbolic, Union{<:Tuple, <:AbstractArray}),
190+
]
185191
@eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2)
186192
getters = getu.((sys,), sym)
187193
_call(getter, args...) = getter(args...)
@@ -252,7 +258,11 @@ function _setu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
252258
error("Invalid symbol $sym for `setu`")
253259
end
254260

255-
for (t1, t2) in [(ScalarSymbolic, Any), (ArraySymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})]
261+
for (t1, t2) in [
262+
(ScalarSymbolic, Any),
263+
(ArraySymbolic, Any),
264+
(NotSymbolic, Union{<:Tuple, <:AbstractArray}),
265+
]
256266
@eval function _setu(sys, ::NotSymbolic, ::$t1, sym::$t2)
257267
setters = setu.((sys,), sym)
258268
return function setter!(prob, val)

test/state_indexing_test.jl

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ p = [11.0, 12.0, 13.0]
1818
t = 0.5
1919
fi = FakeIntegrator(sys, copy(u), copy(p), t)
2020
# checking inference for non-concretely typed arrays will always fail
21-
for (sym, val, newval, check_inference) in [
22-
(:x, u[1], 4.0, true)
21+
for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true)
2322
(:y, u[2], 4.0, true)
2423
(:z, u[3], 4.0, true)
2524
(1, u[1], 4.0, true)
@@ -36,8 +35,7 @@ for (sym, val, newval, check_inference) in [
3635
((:x, [:y, :z]), (u[1], u[2:3]), (4.0, [5.0, 6.0]), true)
3736
((:x, (:y, :z)), (u[1], (u[2], u[3])), (4.0, (5.0, 6.0)), true)
3837
((1, (:y, :z)), (u[1], (u[2], u[3])), (4.0, (5.0, 6.0)), true)
39-
((:x, [:y], (:z,)), (u[1], [u[2]], (u[3],)), (4.0, [5.0], (6.0,)), true)
40-
]
38+
((:x, [:y], (:z,)), (u[1], [u[2]], (u[3],)), (4.0, [5.0], (6.0,)), true)]
4139
get = getu(sys, sym)
4240
set! = setu(sys, sym)
4341
if check_inference
@@ -67,15 +65,13 @@ for (sym, val, newval, check_inference) in [
6765
@test get(u) == val
6866
end
6967

70-
for (sym, oldval, newval, check_inference) in [
71-
(:a, p[1], 4.0, true)
68+
for (sym, oldval, newval, check_inference) in [(:a, p[1], 4.0, true)
7269
(:b, p[2], 5.0, true)
7370
(:c, p[3], 6.0, true)
7471
([:a, :b], p[1:2], [4.0, 5.0], true)
7572
((:c, :b), (p[3], p[2]), (6.0, 5.0), true)
7673
([:x, :a], [u[1], p[1]], [4.0, 5.0], false)
77-
((:y, :b), (u[2], p[2]), (5.0, 6.0), true)
78-
]
74+
((:y, :b), (u[2], p[2]), (5.0, 6.0), true)]
7975
get = getu(fi, sym)
8076
set! = setu(fi, sym)
8177
if check_inference
@@ -126,8 +122,7 @@ xvals = getindex.(sol.u, 1)
126122
yvals = getindex.(sol.u, 2)
127123
zvals = getindex.(sol.u, 3)
128124

129-
for (sym, ans, check_inference) in [
130-
(:x, xvals, true)
125+
for (sym, ans, check_inference) in [(:x, xvals, true)
131126
(:y, yvals, true)
132127
(:z, zvals, true)
133128
(1, xvals, true)
@@ -139,17 +134,22 @@ for (sym, ans, check_inference) in [
139134
([:x, [:y, :z]], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)]), false)
140135
([:x, (:y, :z)], vcat.(xvals, tuple.(yvals, zvals)), false)
141136
([1, (:y, :z)], vcat.(xvals, tuple.(yvals, zvals)), false)
142-
([:x, [:y, :z], (:x, :z)], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)), false)
143-
([:x, [:y, 3], (1, :z)], vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)), false)
137+
([:x, [:y, :z], (:x, :z)],
138+
vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)),
139+
false)
140+
([:x, [:y, 3], (1, :z)],
141+
vcat.(xvals, [[x] for x in vcat.(yvals, zvals)], tuple.(xvals, zvals)),
142+
false)
144143
((:x, [:y, :z]), tuple.(xvals, vcat.(yvals, zvals)), true)
145144
((:x, (:y, :z)), tuple.(xvals, tuple.(yvals, zvals)), true)
146-
((:x, [:y, :z], (:z, :y)), tuple.(xvals, vcat.(yvals, zvals), tuple.(zvals, yvals)), true)
145+
((:x, [:y, :z], (:z, :y)),
146+
tuple.(xvals, vcat.(yvals, zvals), tuple.(zvals, yvals)),
147+
true)
147148
([:x, :a], vcat.(xvals, p[1]), false)
148149
((:y, :b), tuple.(yvals, p[2]), true)
149150
(:t, t, true)
150151
([:x, :a, :t], vcat.(xvals, p[1], t), false)
151-
((:x, :a, :t), tuple.(xvals, p[1], t), true)
152-
]
152+
((:x, :a, :t), tuple.(xvals, p[1], t), true)]
153153
get = getu(sys, sym)
154154
if check_inference
155155
@inferred get(sol)
@@ -163,13 +163,11 @@ for (sym, ans, check_inference) in [
163163
end
164164
end
165165

166-
for (sym, val) in [
167-
(:a, p[1])
166+
for (sym, val) in [(:a, p[1])
168167
(:b, p[2])
169168
(:c, p[3])
170169
([:a, :b], p[1:2])
171-
((:c, :b), (p[3], p[2]))
172-
]
170+
((:c, :b), (p[3], p[2]))]
173171
get = getu(fi, sym)
174172
@inferred get(fi)
175173
@test get(fi) == val

0 commit comments

Comments
 (0)