Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.3.37"
[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

Expand All @@ -14,6 +15,7 @@ Accessors = "0.1.36"
Aqua = "0.8"
ArrayInterface = "7.9"
Pkg = "1"
PrettyTables = "2.4.0"
RuntimeGeneratedFunctions = "0.5.12"
SafeTestsets = "0.0.1"
StaticArrays = "1.9"
Expand Down
3 changes: 2 additions & 1 deletion src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using RuntimeGeneratedFunctions
import StaticArraysCore: MArray, similar_type
import ArrayInterface
using Accessors: @reset
using PrettyTables # for pretty printing

RuntimeGeneratedFunctions.init(@__MODULE__)

Expand Down Expand Up @@ -44,7 +45,7 @@ include("batched_interface.jl")
export ProblemState
include("problem_state.jl")

export ParameterIndexingProxy
export ParameterIndexingProxy, showparams
include("parameter_indexing_proxy.jl")

export remake_buffer
Expand Down
50 changes: 50 additions & 0 deletions src/parameter_indexing_proxy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,53 @@ end
function Base.setindex!(p::ParameterIndexingProxy, val, idx)
return setp(p.wrapped, idx)(p.wrapped, val)
end

function Base.show(io::IO, ::MIME"text/plain", pip::ParameterIndexingProxy)
showparams(io, pip; num_rows = 20, show_all = false, scalarize = true)
end

"""
showparams(io::IO, pip::ParameterIndexingProxy; num_rows = 20, show_all = false, scalarize = true, kwargs...)

Method for customizing the table output. Keyword args:
- num_rows
- show_all: whether to show all parameters. Overrides `num_rows`.
- scalarize: whether to scalarize array symbolics in the table output.
- kwargs... are passed to the pretty_table call.
"""
function showparams(io::IO, pip::ParameterIndexingProxy; num_rows = 20, show_all = false, scalarize = true, kwargs...)
params = Any[]
vals = Any[]
for p in parameter_symbols(pip.wrapped)
if symbolic_type(p) === ArraySymbolic() && scalarize
val = getp(pip.wrapped, p)(pip.wrapped)
for (_p, _v) in zip(collect(p), val)
push!(params, _p)
push!(vals, _v)
end
else
push!(params, p)
val = getp(pip.wrapped, p)(pip.wrapped)
push!(vals, val)
end
end

num_shown = if show_all
length(params)
else
if num_rows > length(params)
length(params)
else
num_rows
end
end

pretty_table(io, [params[1:num_shown] vals[1:num_shown]];
header=["Parameter", "Value"],
kwargs...)

if num_shown < length(params)
println(io,
"$num_shown of $(length(params)) params shown. To show all the parameters, call `showparams(io, ps, show_all = true)`. Adjust the number of rows with the num_rows kwarg. Consult `showparams` docstring for more options.")
end
end
Loading