Skip to content

Commit ab44889

Browse files
Merge pull request #853 from AayushSabharwal/as/getsetsym
refactor: use `getsym`/`setsym` over `getu`/`setu`
2 parents 51261d1 + 909f93c commit ab44889

11 files changed

+71
-71
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ StableRNGs = "1.0"
8888
StaticArrays = "1.7"
8989
StaticArraysCore = "1.4"
9090
Statistics = "1.10"
91-
SymbolicIndexingInterface = "0.3.31"
91+
SymbolicIndexingInterface = "0.3.34"
9292
Tables = "1.11"
9393
Zygote = "0.6.67"
9494
julia = "1.10"

src/integrator_interface.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, sym)
492492
if is_parameter(A, sym)
493493
error("Indexing with parameters is deprecated. Use `integrator.ps[$sym]` for parameter indexing.")
494494
end
495-
return getu(A, sym)(A)
495+
return getsym(A, sym)(A)
496496
end
497497

498498
Base.@propagate_inbounds function Base.getindex(
@@ -501,7 +501,7 @@ Base.@propagate_inbounds function Base.getindex(
501501
is_parameter(A, sym)
502502
error("Indexing with parameters is deprecated. Use `integrator.ps[$sym]` for parameter indexing.")
503503
end
504-
return getu(A, sym)(A)
504+
return getsym(A, sym)(A)
505505
end
506506

507507
Base.@propagate_inbounds function Base.getindex(
@@ -522,15 +522,15 @@ function Base.setindex!(A::DEIntegrator, val, sym)
522522
if is_parameter(A, sym)
523523
error("Indexing with parameters is deprecated. Use `integrator.ps[$sym] = $val` for parameter indexing.")
524524
end
525-
setu(A, sym)(A, val)
525+
setsym(A, sym)(A, val)
526526
end
527527

528528
function Base.setindex!(A::DEIntegrator, val, sym::Union{AbstractArray, Tuple})
529529
if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) ||
530530
is_parameter(A, sym)
531531
error("Indexing with parameters is deprecated. Use `integrator.ps[$sym] = $val` for parameter indexing.")
532532
end
533-
setu(A, sym)(A, val)
533+
setsym(A, sym)(A, val)
534534
end
535535

536536
### Integrator traits

src/problems/problem_interface.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractSciMLProblem, sym)
4242
if is_parameter(A, sym)
4343
error("Indexing with parameters is deprecated. Use `prob.ps[$sym]` for parameter indexing.")
4444
end
45-
return getu(A, sym)(A)
45+
return getsym(A, sym)(A)
4646
end
4747

4848
Base.@propagate_inbounds function Base.getindex(
@@ -51,7 +51,7 @@ Base.@propagate_inbounds function Base.getindex(
5151
is_parameter(A, sym)
5252
error("Indexing with parameters is deprecated. Use `prob.ps[$sym]` for parameter indexing.")
5353
end
54-
return getu(A, sym)(A)
54+
return getsym(A, sym)(A)
5555
end
5656

5757
function Base.setindex!(prob::AbstractSciMLProblem, args...; kwargs...)
@@ -62,7 +62,7 @@ function ___internal_setindex!(A::AbstractSciMLProblem, val, sym)
6262
if is_parameter(A, sym)
6363
error("Indexing with parameters is deprecated. Use `prob.ps[$sym] = $val` for parameter indexing.")
6464
end
65-
return setu(A, sym)(A, val)
65+
return setsym(A, sym)(A, val)
6666
end
6767

6868
function ___internal_setindex!(
@@ -71,5 +71,5 @@ function ___internal_setindex!(
7171
is_parameter(A, sym)
7272
error("Indexing with parameters is deprecated. Use `prob.ps[$sym] = $val` for parameter indexing.")
7373
end
74-
return setu(A, sym)(A, val)
74+
return setsym(A, sym)(A, val)
7575
end

src/remake.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,7 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false)
593593
(symbolic_type(defval) != NotSymbolic() || use_defaults)
594594
defval
595595
else
596-
getu(prob, sym)(prob)
596+
getsym(prob, sym)(prob)
597597
end
598598
end
599599
newvals = anydict()
@@ -671,7 +671,7 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0)
671671
# used, since any state symbols in the expression were substituted out earlier.
672672
temp_state = ProblemState(; u = state_values(prob), p = p, t = t0)
673673
for (k, v) in u0
674-
u0[k] = symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
674+
u0[k] = symbolic_type(v) === NotSymbolic() ? v : getsym(prob, v)(temp_state)
675675
end
676676
return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p
677677
end
@@ -692,7 +692,7 @@ function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0)
692692
# used, since any parameter symbols in the expression were substituted out earlier.
693693
temp_state = ProblemState(; u = u0, p = parameter_values(prob), t = t0)
694694
for (k, v) in p
695-
p[k] = symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
695+
p[k] = symbolic_type(v) === NotSymbolic() ? v : getsym(prob, v)(temp_state)
696696
end
697697
return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))
698698
end

src/solutions/ode_solutions.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs,
296296
end
297297
end
298298
state = ProblemState(; u = sol.interp(t, nothing, deriv, ps, continuity), p = ps, t)
299-
return getu(sol, idxs)(state)
299+
return getsym(sol, idxs)(state)
300300
end
301301

302302
function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVector,
@@ -321,15 +321,15 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect
321321
end
322322
end
323323
state = ProblemState(; u = sol.interp(t, nothing, deriv, ps, continuity), p = ps, t)
324-
return getu(sol, idxs)(state)
324+
return getsym(sol, idxs)(state)
325325
end
326326

327327
function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs,
328328
continuity) where {deriv}
329329
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
330330
error_if_observed_derivative(sol, idxs, deriv)
331331
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
332-
getter = getu(sol, idxs)
332+
getter = getsym(sol, idxs)
333333
if is_parameter_timeseries(sol) == NotTimeseries() || !is_discrete_expression(sol, idxs)
334334
interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol)
335335
return DiffEqArray(getter(interp_sol), t, p, sol)
@@ -353,7 +353,7 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
353353
end
354354
error_if_observed_derivative(sol, idxs, deriv)
355355
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
356-
getter = getu(sol, idxs)
356+
getter = getsym(sol, idxs)
357357
if is_parameter_timeseries(sol) == NotTimeseries() || !is_discrete_expression(sol, idxs)
358358
interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol)
359359
return DiffEqArray(getter(interp_sol), t, p, sol)

src/solutions/solution_interface.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym)
114114
if is_parameter(A, sym)
115115
error("Indexing with parameters is deprecated. Use `sol.ps[$sym]` for parameter indexing.")
116116
end
117-
return getu(A, sym)(A)
117+
return getsym(A, sym)(A)
118118
end
119119

120120
Base.@propagate_inbounds function Base.getindex(
@@ -123,7 +123,7 @@ Base.@propagate_inbounds function Base.getindex(
123123
is_parameter(A, sym)
124124
error("Indexing with parameters is deprecated. Use `sol.ps[$sym]` for parameter indexing.")
125125
end
126-
return getu(A, sym)(A)
126+
return getsym(A, sym)(A)
127127
end
128128

129129
Base.@propagate_inbounds function Base.getindex(
@@ -359,7 +359,7 @@ plottable_indices(x::Number) = 1
359359
xvar = only(independent_variable_symbols(sol))
360360
end
361361
xvals = sol(ts; idxs = xvar).u
362-
# xvals = getu(sol, xvar)(sol, tstart:tend)
362+
# xvals = getsym(sol, xvar)(sol, tstart:tend)
363363
yvals = getp(sol, yvar)(sol, tstart:tend)
364364
tmpvals = map(func, xvals, yvals)
365365
xvals = getindex.(tmpvals, 1)

test/downstream/adjoints.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ sys2 = complete(sys2)
8080
prob2 = ODEProblem(sys2, [], (0.0, 10.0))
8181

8282
bi = BatchedInterface((sys1, [x, y, z]), (sys2, [x, y, w]))
83-
getter = getu(bi)
83+
getter = getsym(bi)
8484

8585
p1grad, p2grad = Zygote.gradient(prob1, prob2) do prob1, prob2
8686
sum(getter(prob1, prob2))

test/downstream/comprehensive_indexing.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,8 @@ timeseries_systems = [osys, ssys, jsys]
119119
((indp.X, indp.Y), Tuple(u[uidxs]), (4.0, 4.0))
120120
((:X, :Y), Tuple(u[uidxs]), (4.0, 4.0))
121121
(Tuple(uidxs), Tuple(u[uidxs]), (4.0, 4.0))]
122-
get = getu(indp, sym)
123-
set! = setu(indp, sym)
122+
get = getsym(indp, sym)
123+
set! = setsym(indp, sym)
124124
@inferred get(valp)
125125
@test get(valp) == val
126126
if valp isa JumpProblem && sym isa Union{Tuple, AbstractArray}
@@ -153,12 +153,12 @@ timeseries_systems = [osys, ssys, jsys]
153153
([X, indp.Y, :XY, X * Y], [u[uidxs]..., sum(u), prod(u)])
154154
((X, indp.Y, :XY, X * Y), (u[uidxs]..., sum(u), prod(u)))
155155
(X * Y, prod(u))]
156-
get = getu(indp, sym)
156+
get = getsym(indp, sym)
157157
@test get(valp) == val
158158
end
159159
end
160160

161-
getter = getu(indp, [])
161+
getter = getsym(indp, [])
162162
@test getter(valp) == []
163163

164164
p = getindex.((Dict(p_vals),), [kp, kd, k1, k2])
@@ -264,7 +264,7 @@ end
264264
true)
265265
(X * Y, xvals .* yvals,
266266
false, true)]
267-
get = getu(indp, sym)
267+
get = getsym(indp, sym)
268268
if check_inference
269269
@inferred get(valp)
270270
end
@@ -416,9 +416,9 @@ end
416416
[[i, [k, j], (k, j)] for (i, j, k) in zip(x_val, y_val, obs_val)], false)
417417
]
418418
if check_inference
419-
@inferred getu(prob, sym)(sol)
419+
@inferred getsym(prob, sym)(sol)
420420
end
421-
@test getu(prob, sym)(sol) == val
421+
@test getsym(prob, sym)(sol) == val
422422
end
423423
end
424424

@@ -454,8 +454,8 @@ end
454454
((x[1:2], (y_idx, x[3])), (x_probval[1:2], (y_probval, x_probval[3])),
455455
(x_newval[1:2], (y_newval, x_newval[3])), false)
456456
]
457-
getter = getu(prob, sym)
458-
setter! = setu(prob, sym)
457+
getter = getsym(prob, sym)
458+
setter! = setsym(prob, sym)
459459
if check_inference
460460
@inferred getter(prob)
461461
end
@@ -818,7 +818,7 @@ end
818818
([kp, 2ud1], true, vcat.(kpval, 2 .* ud1val), false),
819819
((kp, 2ud1), true, tuple.(kpval, 2 .* ud1val), false)
820820
]
821-
getter = getu(sys, sym)
821+
getter = getsym(sys, sym)
822822
if check_inference
823823
@inferred getter(sol)
824824
end
@@ -853,7 +853,7 @@ end
853853
([2x, 3xd1], [2_xval, 3_xd1val], true),
854854
((2x, 3xd2), (2_xval, 3_xd2val), true)
855855
]
856-
getter = getu(sys, sym)
856+
getter = getsym(sys, sym)
857857
@test_throws Exception getter(sol)
858858
for subidx in [1, CartesianIndex(1), :, rand(Bool, 4), rand(1:4, 3), 1:2]
859859
@test_throws Exception getter(sol, subidx)

test/downstream/integrator_indexing.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -330,21 +330,21 @@ integrator = init(prob, Tsit5(), save_everystep = false)
330330
@test integrator[x] isa Vector{Float64}
331331
@test integrator[@nonamespace sys.x] isa Vector{Float64}
332332

333-
getx = getu(integrator, x)
334-
gety = getu(integrator, :y)
335-
get_arr = getu(integrator, [x, y])
336-
get_tuple = getu(integrator, (x, y))
337-
get_obs = getu(integrator, x[1] / p[1])
333+
getx = getsym(integrator, x)
334+
gety = getsym(integrator, :y)
335+
get_arr = getsym(integrator, [x, y])
336+
get_tuple = getsym(integrator, (x, y))
337+
get_obs = getsym(integrator, x[1] / p[1])
338338
@test getx(integrator) == [1.0, 2.0, 3.0]
339339
@test gety(integrator) == 1.0
340340
@test get_arr(integrator) == [[1.0, 2.0, 3.0], 1.0]
341341
@test get_tuple(integrator) == ([1.0, 2.0, 3.0], 1.0)
342342
@test get_obs(integrator) == 1.0
343343

344-
setx! = setu(integrator, x)
345-
sety! = setu(integrator, :y)
346-
set_arr! = setu(integrator, [x, y])
347-
set_tuple! = setu(integrator, (x, y))
344+
setx! = setsym(integrator, x)
345+
sety! = setsym(integrator, :y)
346+
set_arr! = setsym(integrator, [x, y])
347+
set_tuple! = setsym(integrator, (x, y))
348348

349349
setx!(integrator, [4.0, 5.0, 6.0])
350350
@test getx(integrator) == [4.0, 5.0, 6.0]

test/downstream/modelingtoolkit_remake.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ for (sys, prob) in zip(syss, probs)
7979
@inferred typeof(prob) remake(prob)
8080

8181
baseType = Base.typename(typeof(prob)).wrapper
82-
ugetter = getu(prob, [x, y, z])
82+
ugetter = getsym(prob, [x, y, z])
8383
prob2 = @inferred baseType remake(prob; u0 = [x => 2.0, y => 3.0, z => 4.0])
8484
@test ugetter(prob2) == [2.0, 3.0, 4.0]
8585
prob2 = @inferred baseType remake(prob; u0 = [sys.x => 2.0, sys.y => 3.0, sys.z => 4.0])

0 commit comments

Comments
 (0)