diff --git a/Project.toml b/Project.toml index a392450..7fae40f 100644 --- a/Project.toml +++ b/Project.toml @@ -6,10 +6,15 @@ version = "0.3.42" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +[weakdeps] +PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" + +[extensions] +SymbolicIndexingInterfacePrettyTablesExt = "PrettyTables" + [compat] Accessors = "0.1.36" Aqua = "0.8" diff --git a/ext/SymbolicIndexingInterfacePrettyTablesExt/SymbolicIndexingInterfacePrettyTablesExt.jl b/ext/SymbolicIndexingInterfacePrettyTablesExt/SymbolicIndexingInterfacePrettyTablesExt.jl new file mode 100644 index 0000000..ba00172 --- /dev/null +++ b/ext/SymbolicIndexingInterfacePrettyTablesExt/SymbolicIndexingInterfacePrettyTablesExt.jl @@ -0,0 +1,48 @@ +module SymbolicIndexingInterfacePrettyTablesExt + +using SymbolicIndexingInterface +using SymbolicIndexingInterface: ParameterIndexingProxy, parameter_symbols, symbolic_type, + ArraySymbolic, getp +using PrettyTables + +# Override the fallback implementation with the PrettyTables version +function SymbolicIndexingInterface.show_params( + io::IO, pip::ParameterIndexingProxy; num_rows = 20, + show_all = false, scalarize = true, kwargs...) + params = Any[] + vals = Any[] + for p in parameter_symbols(pip.wrapped) + if symbolic_type(p) === ArraySymbolic() && scalarize + val = getp(pip.wrapped, p)(pip.wrapped) + for (_p, _v) in zip(collect(p), val) + push!(params, _p) + push!(vals, _v) + end + else + push!(params, p) + val = getp(pip.wrapped, p)(pip.wrapped) + push!(vals, val) + end + end + + num_shown = if show_all + length(params) + else + if num_rows > length(params) + length(params) + else + num_rows + end + end + + pretty_table(io, [params[1:num_shown] vals[1:num_shown]]; + header = ["Parameter", "Value"], + kwargs...) + + if num_shown < length(params) + println(io, + "$num_shown of $(length(params)) params shown. To show all the parameters, call `show_params(io, ps, show_all = true)`. Adjust the number of rows with the num_rows kwarg. Consult `show_params` docstring for more options.") + end +end + +end diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index 40d49b8..740af70 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -4,7 +4,6 @@ using RuntimeGeneratedFunctions import StaticArraysCore: MArray, similar_type import ArrayInterface using Accessors: @reset -using PrettyTables # for pretty printing RuntimeGeneratedFunctions.init(@__MODULE__) diff --git a/src/parameter_indexing_proxy.jl b/src/parameter_indexing_proxy.jl index 09c87d7..247349d 100644 --- a/src/parameter_indexing_proxy.jl +++ b/src/parameter_indexing_proxy.jl @@ -29,9 +29,12 @@ Method for customizing the table output. Keyword args: - num_rows - show_all: whether to show all parameters. Overrides `num_rows`. - scalarize: whether to scalarize array symbolics in the table output. -- kwargs... are passed to the pretty_table call. +- kwargs... are passed to the pretty_table call (if PrettyTables is loaded). """ -function show_params(io::IO, pip::ParameterIndexingProxy; num_rows = 20, +show_params(io, pip; kwargs...) = _show_params(io, pip; kwargs...) + +# Fallback implementation when PrettyTables is not loaded +function _show_params(io::IO, pip::ParameterIndexingProxy; num_rows = 20, show_all = false, scalarize = true, kwargs...) params = Any[] vals = Any[] @@ -59,9 +62,14 @@ function show_params(io::IO, pip::ParameterIndexingProxy; num_rows = 20, end end - pretty_table(io, [params[1:num_shown] vals[1:num_shown]]; - header = ["Parameter", "Value"], - kwargs...) + # Fallback implementation without PrettyTables + println(io, "Parameter Indexing Proxy") + println(io, "=" ^ 50) + println(io, "Parameter | Value") + println(io, "-" ^ 50) + for i in 1:num_shown + println(io, rpad(string(params[i]), 24) * " | " * string(vals[i])) + end if num_shown < length(params) println(io,