Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.3.37"
[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

Expand All @@ -14,6 +15,7 @@ Accessors = "0.1.36"
Aqua = "0.8"
ArrayInterface = "7.9"
Pkg = "1"
PrettyTables = "2.4.0"
RuntimeGeneratedFunctions = "0.5.12"
SafeTestsets = "0.0.1"
StaticArrays = "1.9"
Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"

[compat]
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ getp
setp
setp_oop
ParameterIndexingProxy
show_params
```

#### Parameter timeseries
Expand Down
30 changes: 18 additions & 12 deletions docs/src/complete_sii.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
This tutorial will show how to define the entire Symbolic Indexing Interface on an
`ExampleSystem`:

```julia
```@example implementing_sii
using SymbolicIndexingInterface
struct ExampleSystem
state_index::Dict{Symbol,Int}
parameter_index::Dict{Symbol,Int}
Expand All @@ -24,7 +25,7 @@ supports specific functionality. Consider the following struct, which needs to i

These are the simple functions which describe how to turn symbols into indices.

```julia
```@example implementing_sii
function SymbolicIndexingInterface.is_variable(sys::ExampleSystem, sym)
haskey(sys.state_index, sym)
end
Expand Down Expand Up @@ -65,7 +66,7 @@ end

SymbolicIndexingInterface.constant_structure(::ExampleSystem) = true

function SymbolicIndexingInterface.all_solvable_symbols(sys::ExampleSystem)
function SymbolicIndexingInterface.all_variable_symbols(sys::ExampleSystem)
return vcat(
collect(keys(sys.state_index)),
collect(keys(sys.observed)),
Expand All @@ -74,7 +75,7 @@ end

function SymbolicIndexingInterface.all_symbols(sys::ExampleSystem)
return vcat(
all_solvable_symbols(sys),
all_variable_symbols(sys),
collect(keys(sys.parameter_index)),
sys.independent_variable === nothing ? Symbol[] : sys.independent_variable
)
Expand All @@ -90,7 +91,7 @@ end
These are for handling symbolic expressions and generating equations which are not directly
in the solution vector.

```julia
```@example implementing_sii
using RuntimeGeneratedFunctions
RuntimeGeneratedFunctions.init(@__MODULE__)

Expand Down Expand Up @@ -167,7 +168,7 @@ not typically useful for solution objects, it may be useful for integrators. Typ
the default implementations for `getp` and `setp` will suffice, and manually defining
them is not necessary.

```julia
```@example implementing_sii
function SymbolicIndexingInterface.parameter_values(sys::ExampleSystem)
sys.p
end
Expand All @@ -183,7 +184,7 @@ the system's symbols. This also requires that the type implement

Consider the following `ExampleIntegrator`

```julia
```@example implementing_sii
mutable struct ExampleIntegrator
u::Vector{Float64}
p::Vector{Float64}
Expand All @@ -199,8 +200,8 @@ SymbolicIndexingInterface.current_time(sys::ExampleIntegrator) = sys.t
```

Then the following example would work:
```julia
sys = ExampleSystem(Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t, Dict())
```@example implementing_sii
sys = ExampleSystem(Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t, Dict(), Dict())
integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, sys)
getx = getsym(sys, :x)
getx(integrator) # 1.0
Expand Down Expand Up @@ -289,7 +290,7 @@ interface and allows using [`getp`](@ref) and [`setp`](@ref) to get and set para
values. This allows for a cleaner interface for parameter indexing. Consider the
following example for `ExampleIntegrator`:

```julia
```@example implementing_sii
function Base.getproperty(obj::ExampleIntegrator, sym::Symbol)
if sym === :ps
return ParameterIndexingProxy(obj)
Expand All @@ -301,8 +302,8 @@ end

This enables the following API:

```julia
integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t)
```@example implementing_sii
integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, sys)

integrator.ps[:a] # 4.0
getp(integrator, :a)(integrator) # functionally the same as above
Expand All @@ -311,6 +312,11 @@ integrator.ps[:b] = 3.0
setp(integrator, :b)(integrator, 3.0) # functionally the same as above
```

The parameters will display as a table:
```@example implementing_sii
integrator.ps
```

## Parameter Timeseries

If a solution object includes modified parameter values (such as through callbacks) during the
Expand Down
3 changes: 2 additions & 1 deletion src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using RuntimeGeneratedFunctions
import StaticArraysCore: MArray, similar_type
import ArrayInterface
using Accessors: @reset
using PrettyTables # for pretty printing

RuntimeGeneratedFunctions.init(@__MODULE__)

Expand Down Expand Up @@ -44,7 +45,7 @@ include("batched_interface.jl")
export ProblemState
include("problem_state.jl")

export ParameterIndexingProxy
export ParameterIndexingProxy, show_params
include("parameter_indexing_proxy.jl")

export remake_buffer
Expand Down
51 changes: 51 additions & 0 deletions src/parameter_indexing_proxy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,54 @@ end
function Base.setindex!(p::ParameterIndexingProxy, val, idx)
return setp(p.wrapped, idx)(p.wrapped, val)
end

function Base.show(io::IO, ::MIME"text/plain", pip::ParameterIndexingProxy)
show_params(io, pip; num_rows = 20, show_all = false, scalarize = true)
end

"""
show_params(io::IO, pip::ParameterIndexingProxy; num_rows = 20, show_all = false, scalarize = true, kwargs...)

Method for customizing the table output. Keyword args:
- num_rows
- show_all: whether to show all parameters. Overrides `num_rows`.
- scalarize: whether to scalarize array symbolics in the table output.
- kwargs... are passed to the pretty_table call.
"""
function show_params(io::IO, pip::ParameterIndexingProxy; num_rows = 20,
show_all = false, scalarize = true, kwargs...)
params = Any[]
vals = Any[]
for p in parameter_symbols(pip.wrapped)
if symbolic_type(p) === ArraySymbolic() && scalarize
val = getp(pip.wrapped, p)(pip.wrapped)
for (_p, _v) in zip(collect(p), val)
push!(params, _p)
push!(vals, _v)
end
else
push!(params, p)
val = getp(pip.wrapped, p)(pip.wrapped)
push!(vals, val)
end
end

num_shown = if show_all
length(params)
else
if num_rows > length(params)
length(params)
else
num_rows
end
end

pretty_table(io, [params[1:num_shown] vals[1:num_shown]];
header = ["Parameter", "Value"],
kwargs...)

if num_shown < length(params)
println(io,
"$num_shown of $(length(params)) params shown. To show all the parameters, call `show_params(io, ps, show_all = true)`. Adjust the number of rows with the num_rows kwarg. Consult `show_params` docstring for more options.")
end
end
Loading