Skip to content

Commit 2a3e88d

Browse files
Merge pull request SciML#721 from AayushSabharwal/as/getu-everywhere
refactor: use `getu`/`setu` for all indexing
2 parents e3a0de8 + 10f71ab commit 2a3e88d

File tree

12 files changed

+651
-505
lines changed

12 files changed

+651
-505
lines changed

.github/workflows/Downstream.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ jobs:
6262
repository: ${{ matrix.package.user }}/${{ matrix.package.repo }}
6363
path: downstream
6464
- name: Load this and run the downstream tests
65-
shell: julia --color=yes --project=downstream {0}
65+
shell: julia --color=yes --project=downstream --depwarn=yes {0}
6666
run: |
6767
using Pkg
6868
try

.github/workflows/Tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ jobs:
3434
- "Python"
3535
uses: "SciML/.github/.github/workflows/tests.yml@v1"
3636
with:
37+
julia-runtest-depwarn: "yes"
3738
group: "${{ matrix.group }}"
3839
julia-version: "${{ matrix.version }}"
3940
secrets: "inherit"

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ Reexport = "1"
8383
RuntimeGeneratedFunctions = "0.5.12"
8484
SciMLOperators = "0.3.7"
8585
SciMLStructures = "1.1"
86+
StableRNGs = "1.0"
8687
StaticArrays = "1.7"
8788
StaticArraysCore = "1.4"
8889
Statistics = "1.10"
@@ -106,11 +107,12 @@ PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
106107
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
107108
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
108109
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
110+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
109111
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
110112
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
111113
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
112114
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
113115
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
114116

115117
[targets]
116-
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq", "ForwardDiff"]
118+
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq", "ForwardDiff"]

src/integrator_interface.jl

Lines changed: 21 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -465,46 +465,20 @@ function Base.getproperty(A::DEIntegrator, sym::Symbol)
465465
end
466466
end
467467

468-
Base.@propagate_inbounds function _getindex(A::DEIntegrator,
469-
::NotSymbolic,
470-
I::Union{Int, AbstractArray{Int},
471-
CartesianIndex, Colon, BitArray,
472-
AbstractArray{Bool}}...)
473-
A.u[I...]
474-
end
475-
476-
Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ScalarSymbolic, sym)
477-
if is_variable(A, sym)
478-
return A[variable_index(A, sym)]
479-
elseif is_parameter(A, sym)
480-
error("Indexing with parameters is deprecated. Use `getp(sys, $sym)(integrator)` for parameter indexing")
481-
elseif is_independent_variable(A, sym)
482-
return A.t
483-
elseif is_observed(A, sym)
484-
return SymbolicIndexingInterface.observed(A, sym)(A.u, A.p, A.t)
485-
else
486-
error("Tried to index integrator with a Symbol that was not found in the system.")
468+
Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, sym)
469+
if is_parameter(A, sym)
470+
error("Indexing with parameters is deprecated. Use `integrator.ps[$sym]` for parameter indexing.")
487471
end
472+
return getu(A, sym)(A)
488473
end
489474

490-
Base.@propagate_inbounds function _getindex(A::DEIntegrator, ::ArraySymbolic, sym)
491-
return A[collect(sym)]
492-
end
493-
494-
Base.@propagate_inbounds function _getindex(
495-
A::DEIntegrator, ::ScalarSymbolic, sym::Union{Tuple, AbstractArray})
496-
return getindex.((A,), sym)
497-
end
498-
499-
Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, sym)
500-
symtype = symbolic_type(sym)
501-
elsymtype = symbolic_type(eltype(sym))
502-
503-
if symtype != NotSymbolic()
504-
return _getindex(A, symtype, sym)
505-
else
506-
return _getindex(A, elsymtype, sym)
475+
Base.@propagate_inbounds function Base.getindex(
476+
A::DEIntegrator, sym::Union{AbstractArray, Tuple})
477+
if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) ||
478+
is_parameter(A, sym)
479+
error("Indexing with parameters is deprecated. Use `integrator.ps[$sym]` for parameter indexing.")
507480
end
481+
return getu(A, sym)(A)
508482
end
509483

510484
Base.@propagate_inbounds function Base.getindex(
@@ -522,25 +496,18 @@ function observed(A::DEIntegrator, sym)
522496
end
523497

524498
function Base.setindex!(A::DEIntegrator, val, sym)
525-
has_sys(A.f) ||
526-
error("Invalid indexing of integrator: Integrator does not support indexing without a system")
527-
if symbolic_type(sym) == ScalarSymbolic()
528-
if is_variable(A, sym)
529-
set_state!(A, val, variable_index(A, sym))
530-
elseif is_parameter(A, sym)
531-
error("Parameter indexing is deprecated. Use `setp(sys, $sym)(integrator, $val)` to set parameter value.")
532-
else
533-
error("Invalid indexing of integrator: $sym is not a state or parameter, it may be an observed variable.")
534-
end
535-
return A
536-
elseif symbolic_type(sym) == ArraySymbolic()
537-
setindex!.((A,), val, collect(sym))
538-
return A
539-
else
540-
sym isa AbstractArray || error("Invalid indexing of integrator")
541-
setindex!.((A,), val, sym)
542-
return A
499+
if is_parameter(A, sym)
500+
error("Indexing with parameters is deprecated. Use `integrator.ps[$sym] = $val` for parameter indexing.")
501+
end
502+
setu(A, sym)(A, val)
503+
end
504+
505+
function Base.setindex!(A::DEIntegrator, val, sym::Union{AbstractArray, Tuple})
506+
if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) ||
507+
is_parameter(A, sym)
508+
error("Indexing with parameters is deprecated. Use `integrator.ps[$sym] = $val` for parameter indexing.")
543509
end
510+
setu(A, sym)(A, val)
544511
end
545512

546513
### Integrator traits

src/problems/problem_interface.jl

Lines changed: 27 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -38,51 +38,38 @@ Base.@propagate_inbounds function Base.getindex(
3838
return getindex(prob, all_variable_symbols(prob))
3939
end
4040

41-
Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, sym)
42-
if symbolic_type(sym) == ScalarSymbolic()
43-
if is_variable(prob, sym)
44-
return state_values(prob, variable_index(prob, sym))
45-
elseif is_parameter(prob, sym)
46-
error("Indexing with parameters is deprecated. Use `getp(prob, $sym)(prob)` for parameter indexing.")
47-
elseif is_independent_variable(prob, sym)
48-
return current_time(prob)
49-
elseif is_observed(prob, sym)
50-
obs = SymbolicIndexingInterface.observed(prob, sym)
51-
if is_time_dependent(prob)
52-
return obs(state_values(prob), parameter_values(prob), current_time(prob))
53-
else
54-
return obs(state_values(prob), parameter_values(prob))
55-
end
56-
else
57-
error("Invalid indexing of problem: $sym is not a state, parameter, or independent variable")
58-
end
59-
elseif symbolic_type(sym) == ArraySymbolic()
60-
return map(s -> prob[s], collect(sym))
61-
else
62-
sym isa AbstractArray || error("Invalid indexing of problem")
63-
return map(s -> prob[s], sym)
41+
Base.@propagate_inbounds function Base.getindex(A::AbstractSciMLProblem, sym)
42+
if is_parameter(A, sym)
43+
error("Indexing with parameters is deprecated. Use `prob.ps[$sym]` for parameter indexing.")
6444
end
45+
return getu(A, sym)(A)
46+
end
47+
48+
Base.@propagate_inbounds function Base.getindex(
49+
A::AbstractSciMLProblem, sym::Union{AbstractArray, Tuple})
50+
if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) ||
51+
is_parameter(A, sym)
52+
error("Indexing with parameters is deprecated. Use `prob.ps[$sym]` for parameter indexing.")
53+
end
54+
return getu(A, sym)(A)
6555
end
6656

6757
function Base.setindex!(prob::AbstractSciMLProblem, args...; kwargs...)
6858
___internal_setindex!(prob::AbstractSciMLProblem, args...; kwargs...)
6959
end
70-
function ___internal_setindex!(prob::AbstractSciMLProblem, val, sym)
71-
if symbolic_type(sym) == ScalarSymbolic()
72-
if is_variable(prob, sym)
73-
set_state!(prob, val, variable_index(prob, sym))
74-
elseif is_parameter(prob, sym)
75-
error("Indexing with parameters is deprecated. Use `setp(prob, $sym)(prob, $val)` to set parameter value.")
76-
else
77-
error("Invalid indexing of problem: $sym is not a state or parameter, it may be an observed variable.")
78-
end
79-
return prob
80-
elseif symbolic_type(sym) == ArraySymbolic()
81-
setindex!.((prob,), val, collect(sym))
82-
return prob
83-
else
84-
sym isa AbstractArray || error("Invalid indexing of problem")
85-
setindex!.((prob,), val, sym)
86-
return prob
60+
61+
function ___internal_setindex!(A::AbstractSciMLProblem, val, sym)
62+
if is_parameter(A, sym)
63+
error("Indexing with parameters is deprecated. Use `prob.ps[$sym] = $val` for parameter indexing.")
64+
end
65+
return setu(A, sym)(A, val)
66+
end
67+
68+
function ___internal_setindex!(
69+
A::AbstractSciMLProblem, val, sym::Union{AbstractArray, Tuple})
70+
if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) ||
71+
is_parameter(A, sym)
72+
error("Indexing with parameters is deprecated. Use `prob.ps[$sym] = $val` for parameter indexing.")
8773
end
74+
return setu(A, sym)(A, val)
8875
end

src/solutions/optimization_solutions.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,13 @@ function reinit!(cache::SciMLBase.AbstractOptimizationCache; p = missing,
174174
return cache
175175
end
176176

177-
SymbolicIndexingInterface.parameter_values(x::AbstractOptimizationCache) = x.p
177+
function SymbolicIndexingInterface.parameter_values(x::AbstractOptimizationCache)
178+
if has_reinit(x)
179+
x.reinit_cache.p
180+
else
181+
x.p
182+
end
183+
end
178184
SymbolicIndexingInterface.symbolic_container(x::AbstractOptimizationCache) = x.f
179185

180186
get_p(sol::OptimizationSolution) = sol.cache.p

src/solutions/solution_interface.jl

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -63,22 +63,19 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, :
6363
end
6464

6565
Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym)
66-
if symbolic_type(sym) == ScalarSymbolic()
67-
if is_variable(A, sym)
68-
return A[variable_index(A, sym)]
69-
elseif is_parameter(A, sym)
70-
error("Indexing with parameters is deprecated. Use `getp(sys, $sym)(sol)` for parameter indexing.")
71-
elseif is_observed(A, sym)
72-
return SymbolicIndexingInterface.observed(A, sym)(A.u, parameter_values(A))
73-
else
74-
error("Tried to index solution with a Symbol that was not found in the system.")
75-
end
76-
elseif symbolic_type(sym) == ArraySymbolic()
77-
return A[collect(sym)]
78-
else
79-
sym isa AbstractArray || error("Invalid indexing of solution")
80-
return getindex.((A,), sym)
66+
if is_parameter(A, sym)
67+
error("Indexing with parameters is deprecated. Use `sol.ps[$sym]` for parameter indexing.")
68+
end
69+
return getu(A, sym)(A)
70+
end
71+
72+
Base.@propagate_inbounds function Base.getindex(
73+
A::AbstractNoTimeSolution, sym::Union{AbstractArray, Tuple})
74+
if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) ||
75+
is_parameter(A, sym)
76+
error("Indexing with parameters is deprecated. Use `sol.ps[$sym]` for parameter indexing.")
8177
end
78+
return getu(A, sym)(A)
8279
end
8380

8481
Base.@propagate_inbounds function Base.getindex(

test/downstream/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
1313
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1414
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
1515
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
16+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
17+
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
18+
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
1619
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
1720
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
1821
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
@@ -34,6 +37,7 @@ RecursiveArrayTools = "3"
3437
SciMLBase = "2"
3538
SciMLSensitivity = "7.11"
3639
SciMLStructures = "1.1"
40+
SteadyStateDiffEq = "2.2"
3741
Sundials = "4.11"
3842
SymbolicIndexingInterface = "0.3"
3943
SymbolicUtils = "<1.6, 2"

0 commit comments

Comments
 (0)