From 01b76ab64bfbeae754ed901bd45978370e17067d Mon Sep 17 00:00:00 2001 From: vyudu Date: Fri, 24 Jan 2025 12:46:25 -0500 Subject: [PATCH 1/7] init --- Project.toml | 2 ++ src/SymbolicIndexingInterface.jl | 1 + src/parameter_indexing_proxy.jl | 17 +++++++++++++++++ 3 files changed, 20 insertions(+) diff --git a/Project.toml b/Project.toml index 181b196a..6bd28138 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +Term = "22787eb5-b846-44ae-b979-8e399b8463ab" [compat] Accessors = "0.1.36" @@ -18,6 +19,7 @@ RuntimeGeneratedFunctions = "0.5.12" SafeTestsets = "0.0.1" StaticArrays = "1.9" StaticArraysCore = "1.4" +Term = "2.0.7" Test = "1" Zygote = "0.6.67" julia = "1.10" diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index f444fb47..40062cd5 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -4,6 +4,7 @@ using RuntimeGeneratedFunctions import StaticArraysCore: MArray, similar_type import ArrayInterface using Accessors: @reset +import Term: Table # for pretty-printing RuntimeGeneratedFunctions.init(@__MODULE__) diff --git a/src/parameter_indexing_proxy.jl b/src/parameter_indexing_proxy.jl index cc0925d2..7159c9e6 100644 --- a/src/parameter_indexing_proxy.jl +++ b/src/parameter_indexing_proxy.jl @@ -17,3 +17,20 @@ end function Base.setindex!(p::ParameterIndexingProxy, val, idx) return setp(p.wrapped, idx)(p.wrapped, val) end + +function Base.show(io::IO, pip::ParameterIndexingProxy; kwargs...) + params = Any[] + vals = Any[] + for p in parameter_symbols(pip.wrapped) + push!(params, p) + val = getp(pip.wrapped, p)(pip.wrapped) + push!(vals, val) + end + + print( + Table([params vals]; + box=:SIMPLE, + header=["Parameter", "Value"], + kwargs...) + ) +end From 48c0a6f9a147e5a40f3fdcebf5ffb6945e839164 Mon Sep 17 00:00:00 2001 From: vyudu Date: Fri, 21 Feb 2025 10:40:46 -0800 Subject: [PATCH 2/7] refactor to PrettyTables --- Project.toml | 2 ++ src/SymbolicIndexingInterface.jl | 2 +- src/parameter_indexing_proxy.jl | 53 ++++++++++++++++++++++++++------ 3 files changed, 46 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index 6bd28138..5e8a4b5b 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.3.37" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Term = "22787eb5-b846-44ae-b979-8e399b8463ab" @@ -15,6 +16,7 @@ Accessors = "0.1.36" Aqua = "0.8" ArrayInterface = "7.9" Pkg = "1" +PrettyTables = "2.4.0" RuntimeGeneratedFunctions = "0.5.12" SafeTestsets = "0.0.1" StaticArrays = "1.9" diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 40062cd5..dbd01ddc 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -4,7 +4,7 @@ using RuntimeGeneratedFunctions import StaticArraysCore: MArray, similar_type import ArrayInterface using Accessors: @reset -import Term: Table # for pretty-printing +using PrettyTables # for pretty printing RuntimeGeneratedFunctions.init(@__MODULE__) diff --git a/src/parameter_indexing_proxy.jl b/src/parameter_indexing_proxy.jl index 7159c9e6..fd65a8e6 100644 --- a/src/parameter_indexing_proxy.jl +++ b/src/parameter_indexing_proxy.jl @@ -18,19 +18,52 @@ function Base.setindex!(p::ParameterIndexingProxy, val, idx) return setp(p.wrapped, idx)(p.wrapped, val) end -function Base.show(io::IO, pip::ParameterIndexingProxy; kwargs...) +function Base.show(io::IO, ::MIME"text/plain", pip::ParameterIndexingProxy) + showparams(io, pip; num_rows = 20, show_all = false, scalarize = true) +end + +""" + showparams(io::IO, pip::ParameterIndexingProxy; num_rows = 20, show_all = false, scalarize = true, kwargs...) + +Method for customizing the table output. Keyword args: +- num_rows +- show_all: whether to show all parameters +- scalarize: whether to scalarize array symbolics in the table output. +- kwargs... are passed to the pretty_table call. +""" +function showparams(io::IO, pip::ParameterIndexingProxy; num_rows = 20, show_all = false, scalarize = true, kwargs...) params = Any[] vals = Any[] for p in parameter_symbols(pip.wrapped) - push!(params, p) - val = getp(pip.wrapped, p)(pip.wrapped) - push!(vals, val) + if symbolic_type(p) === ArraySymbolic() && scalarize + val = getp(pip.wrapped, p)(pip.wrapped) + for (_p, _v) in zip(collect(p), val) + push!(params, _p) + push!(vals, _v) + end + else + push!(params, p) + val = getp(pip.wrapped, p)(pip.wrapped) + push!(vals, val) + end end - print( - Table([params vals]; - box=:SIMPLE, - header=["Parameter", "Value"], - kwargs...) - ) + num_shown = if show_all + length(params) + else + if num_rows > length(params) + length(params) + else + num_rows + end + end + + pretty_table(io, [params[1:num_shown] vals[1:num_shown]]; + header=["Parameter", "Value"], + kwargs...) + + if num_shown < length(params) + println(io, + "$num_shown of $(length(params)) params shown. To show all the parameters, call `showparams(io, ps, show_all = true)`. Adjust the number of rows with the num_rows kwarg. Consult `showparams` docstring for more options.") + end end From e99f7dfb623bc2cf771b66200e132dd1e030cba3 Mon Sep 17 00:00:00 2001 From: vyudu Date: Fri, 21 Feb 2025 10:41:35 -0800 Subject: [PATCH 3/7] remove Term as dependency --- Project.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/Project.toml b/Project.toml index 5e8a4b5b..788a5fa5 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,6 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -Term = "22787eb5-b846-44ae-b979-8e399b8463ab" [compat] Accessors = "0.1.36" @@ -21,7 +20,6 @@ RuntimeGeneratedFunctions = "0.5.12" SafeTestsets = "0.0.1" StaticArrays = "1.9" StaticArraysCore = "1.4" -Term = "2.0.7" Test = "1" Zygote = "0.6.67" julia = "1.10" From e1c73594ad528c5d7468ed21a57aca124db45610 Mon Sep 17 00:00:00 2001 From: vyudu Date: Fri, 21 Feb 2025 10:47:04 -0800 Subject: [PATCH 4/7] export showparams --- src/SymbolicIndexingInterface.jl | 2 +- src/parameter_indexing_proxy.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index dbd01ddc..26a43b63 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -45,7 +45,7 @@ include("batched_interface.jl") export ProblemState include("problem_state.jl") -export ParameterIndexingProxy +export ParameterIndexingProxy, showparams include("parameter_indexing_proxy.jl") export remake_buffer diff --git a/src/parameter_indexing_proxy.jl b/src/parameter_indexing_proxy.jl index fd65a8e6..c22b29fb 100644 --- a/src/parameter_indexing_proxy.jl +++ b/src/parameter_indexing_proxy.jl @@ -27,7 +27,7 @@ end Method for customizing the table output. Keyword args: - num_rows -- show_all: whether to show all parameters +- show_all: whether to show all parameters. Overrides `num_rows`. - scalarize: whether to scalarize array symbolics in the table output. - kwargs... are passed to the pretty_table call. """ From 2bfcbed25217f74f4f0c35ffb8509ece03f6d0e5 Mon Sep 17 00:00:00 2001 From: vyudu Date: Mon, 24 Feb 2025 10:37:35 -0500 Subject: [PATCH 5/7] Format and add docs --- docs/src/api.md | 1 + docs/src/complete_sii.md | 5 +++++ src/parameter_indexing_proxy.jl | 23 ++++++++++++----------- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 29150c0a..2b6d08de 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -92,6 +92,7 @@ getp setp setp_oop ParameterIndexingProxy +show_params ``` #### Parameter timeseries diff --git a/docs/src/complete_sii.md b/docs/src/complete_sii.md index c75cdc80..ac25f974 100644 --- a/docs/src/complete_sii.md +++ b/docs/src/complete_sii.md @@ -311,6 +311,11 @@ integrator.ps[:b] = 3.0 setp(integrator, :b)(integrator, 3.0) # functionally the same as above ``` +The parameters will display as a table: +```@example show_params +integrator.ps +``` + ## Parameter Timeseries If a solution object includes modified parameter values (such as through callbacks) during the diff --git a/src/parameter_indexing_proxy.jl b/src/parameter_indexing_proxy.jl index c22b29fb..43755932 100644 --- a/src/parameter_indexing_proxy.jl +++ b/src/parameter_indexing_proxy.jl @@ -23,7 +23,7 @@ function Base.show(io::IO, ::MIME"text/plain", pip::ParameterIndexingProxy) end """ - showparams(io::IO, pip::ParameterIndexingProxy; num_rows = 20, show_all = false, scalarize = true, kwargs...) + show_params(io::IO, pip::ParameterIndexingProxy; num_rows = 20, show_all = false, scalarize = true, kwargs...) Method for customizing the table output. Keyword args: - num_rows @@ -31,7 +31,8 @@ Method for customizing the table output. Keyword args: - scalarize: whether to scalarize array symbolics in the table output. - kwargs... are passed to the pretty_table call. """ -function showparams(io::IO, pip::ParameterIndexingProxy; num_rows = 20, show_all = false, scalarize = true, kwargs...) +function showparams(io::IO, pip::ParameterIndexingProxy; num_rows = 20, + show_all = false, scalarize = true, kwargs...) params = Any[] vals = Any[] for p in parameter_symbols(pip.wrapped) @@ -49,21 +50,21 @@ function showparams(io::IO, pip::ParameterIndexingProxy; num_rows = 20, show_all end num_shown = if show_all + length(params) + else + if num_rows > length(params) length(params) else - if num_rows > length(params) - length(params) - else - num_rows - end + num_rows end + end pretty_table(io, [params[1:num_shown] vals[1:num_shown]]; - header=["Parameter", "Value"], - kwargs...) + header = ["Parameter", "Value"], + kwargs...) if num_shown < length(params) - println(io, - "$num_shown of $(length(params)) params shown. To show all the parameters, call `showparams(io, ps, show_all = true)`. Adjust the number of rows with the num_rows kwarg. Consult `showparams` docstring for more options.") + println(io, + "$num_shown of $(length(params)) params shown. To show all the parameters, call `show_params(io, ps, show_all = true)`. Adjust the number of rows with the num_rows kwarg. Consult `show_params` docstring for more options.") end end From 0d92e7f96ecc3d5a1171fe8f3fe2ab70a4056de0 Mon Sep 17 00:00:00 2001 From: vyudu Date: Mon, 24 Feb 2025 12:04:10 -0500 Subject: [PATCH 6/7] add to documentation --- docs/src/complete_sii.md | 27 ++++++++++++++------------- src/SymbolicIndexingInterface.jl | 2 +- src/parameter_indexing_proxy.jl | 4 ++-- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/docs/src/complete_sii.md b/docs/src/complete_sii.md index ac25f974..f6d8ef5f 100644 --- a/docs/src/complete_sii.md +++ b/docs/src/complete_sii.md @@ -4,7 +4,8 @@ This tutorial will show how to define the entire Symbolic Indexing Interface on an `ExampleSystem`: -```julia +```@example implementing_sii +using SymbolicIndexingInterface struct ExampleSystem state_index::Dict{Symbol,Int} parameter_index::Dict{Symbol,Int} @@ -24,7 +25,7 @@ supports specific functionality. Consider the following struct, which needs to i These are the simple functions which describe how to turn symbols into indices. -```julia +```@example implementing_sii function SymbolicIndexingInterface.is_variable(sys::ExampleSystem, sym) haskey(sys.state_index, sym) end @@ -65,7 +66,7 @@ end SymbolicIndexingInterface.constant_structure(::ExampleSystem) = true -function SymbolicIndexingInterface.all_solvable_symbols(sys::ExampleSystem) +function SymbolicIndexingInterface.all_variable_symbols(sys::ExampleSystem) return vcat( collect(keys(sys.state_index)), collect(keys(sys.observed)), @@ -74,7 +75,7 @@ end function SymbolicIndexingInterface.all_symbols(sys::ExampleSystem) return vcat( - all_solvable_symbols(sys), + all_variable_symbols(sys), collect(keys(sys.parameter_index)), sys.independent_variable === nothing ? Symbol[] : sys.independent_variable ) @@ -90,7 +91,7 @@ end These are for handling symbolic expressions and generating equations which are not directly in the solution vector. -```julia +```@example implementing_sii using RuntimeGeneratedFunctions RuntimeGeneratedFunctions.init(@__MODULE__) @@ -167,7 +168,7 @@ not typically useful for solution objects, it may be useful for integrators. Typ the default implementations for `getp` and `setp` will suffice, and manually defining them is not necessary. -```julia +```@example implementing_sii function SymbolicIndexingInterface.parameter_values(sys::ExampleSystem) sys.p end @@ -183,7 +184,7 @@ the system's symbols. This also requires that the type implement Consider the following `ExampleIntegrator` -```julia +```@example implementing_sii mutable struct ExampleIntegrator u::Vector{Float64} p::Vector{Float64} @@ -199,8 +200,8 @@ SymbolicIndexingInterface.current_time(sys::ExampleIntegrator) = sys.t ``` Then the following example would work: -```julia -sys = ExampleSystem(Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t, Dict()) +```@example implementing_sii +sys = ExampleSystem(Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t, Dict(), Dict()) integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, sys) getx = getsym(sys, :x) getx(integrator) # 1.0 @@ -289,7 +290,7 @@ interface and allows using [`getp`](@ref) and [`setp`](@ref) to get and set para values. This allows for a cleaner interface for parameter indexing. Consider the following example for `ExampleIntegrator`: -```julia +```@example implementing_sii function Base.getproperty(obj::ExampleIntegrator, sym::Symbol) if sym === :ps return ParameterIndexingProxy(obj) @@ -301,8 +302,8 @@ end This enables the following API: -```julia -integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t) +```@example implementing_sii +integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, sys) integrator.ps[:a] # 4.0 getp(integrator, :a)(integrator) # functionally the same as above @@ -312,7 +313,7 @@ setp(integrator, :b)(integrator, 3.0) # functionally the same as above ``` The parameters will display as a table: -```@example show_params +```@example implementing_sii integrator.ps ``` diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 26a43b63..40d49b8e 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -45,7 +45,7 @@ include("batched_interface.jl") export ProblemState include("problem_state.jl") -export ParameterIndexingProxy, showparams +export ParameterIndexingProxy, show_params include("parameter_indexing_proxy.jl") export remake_buffer diff --git a/src/parameter_indexing_proxy.jl b/src/parameter_indexing_proxy.jl index 43755932..09c87d7d 100644 --- a/src/parameter_indexing_proxy.jl +++ b/src/parameter_indexing_proxy.jl @@ -19,7 +19,7 @@ function Base.setindex!(p::ParameterIndexingProxy, val, idx) end function Base.show(io::IO, ::MIME"text/plain", pip::ParameterIndexingProxy) - showparams(io, pip; num_rows = 20, show_all = false, scalarize = true) + show_params(io, pip; num_rows = 20, show_all = false, scalarize = true) end """ @@ -31,7 +31,7 @@ Method for customizing the table output. Keyword args: - scalarize: whether to scalarize array symbolics in the table output. - kwargs... are passed to the pretty_table call. """ -function showparams(io::IO, pip::ParameterIndexingProxy; num_rows = 20, +function show_params(io::IO, pip::ParameterIndexingProxy; num_rows = 20, show_all = false, scalarize = true, kwargs...) params = Any[] vals = Any[] From 9ea685e4be410cbec0c4000f53d601c69a4d9e5c Mon Sep 17 00:00:00 2001 From: vyudu Date: Mon, 24 Feb 2025 12:52:03 -0500 Subject: [PATCH 7/7] add RGF as dep --- docs/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/Project.toml b/docs/Project.toml index 593250b9..dfd88684 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,6 +3,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" [compat]