Skip to content

Commit 7f2377d

Browse files
authored
fix train show and tests (#214)
* fix train show and tests * m lines
1 parent 0325d31 commit 7f2377d

File tree

4 files changed

+214
-0
lines changed

4 files changed

+214
-0
lines changed

src/utils/show_train.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@ function _print_nested_keys(io::IO, nt::NamedTuple; indent = 4)
66
pad = " "^(maxkey - length(kstr) + 2)
77
if isa(v, NamedTuple)
88
println(io, prefix, kstr, pad, "(", join(propertynames(v), ", "), ")")
9+
elseif isa(v, Tuple)
10+
# Tuples need special handling - use length or show as scalar
11+
if length(v) == 0
12+
println(io, prefix, kstr, pad, "()")
13+
else
14+
printstyled(io, prefix * kstr * pad; color = 10)
15+
printstyled(io, "($(length(v)),)"; color = :light_black)
16+
println(io)
17+
end
918
else
1019
sz = size(v)
1120
if sz == ()
@@ -49,6 +58,9 @@ function Base.show(io::IO, ::MIME"text/plain", tr::TrainResults)
4958
println(io)
5059
_print_nested_keys(io, val; indent = 4)
5160
else
61+
# Print scalar values or other types
62+
print(io)
63+
printstyled(io, "$(val)"; color = :light_red)
5264
println(io)
5365
end
5466
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ include("test_loss_types.jl")
1616
include("test_show_loss_types.jl")
1717
include("test_compute_loss.jl")
1818
include("test_loss_fn.jl")
19+
include("test_show_train.jl")
20+
include("test_show_generic_hybrid.jl")
1921

2022
@testset "LinearHM" begin
2123
# test model instantiation

test/test_show_generic_hybrid.jl

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
using EasyHybrid: HybridParams, ParameterContainer, SingleNNHybridModel, MultiNNHybridModel
2+
using EasyHybrid: _print_field, _print_header, IndentedIO
3+
4+
@testset "show_generic.jl" begin
5+
6+
@testset "_print_field and _print_header" begin
7+
result = sprint(
8+
io -> begin
9+
_print_header(io, "Test Header", color = :blue)
10+
_print_field(io, "scalar", 42)
11+
_print_field(io, "bool_true", true)
12+
_print_field(io, "bool_false", false)
13+
_print_field(io, "namedtuple", (a = 1, b = 2.0))
14+
_print_field(io, "function", sum)
15+
end, context = :color => false
16+
)
17+
18+
@test occursin("Test Header", result)
19+
@test occursin("scalar", result) && occursin("42", result)
20+
@test occursin("bool_true", result) && occursin("bool_false", result)
21+
@test occursin("namedtuple", result) && occursin("a =", result) && occursin("b =", result)
22+
@test occursin("function", result) && occursin("sum", result)
23+
end
24+
25+
@testset "HybridParams show" begin
26+
params = (a = (1.0, 0.0, 2.0), b = (2.0, 1.0, 3.0))
27+
pc = ParameterContainer(params)
28+
hp = HybridParams{typeof(sum)}(pc)
29+
30+
# Compact show
31+
result_compact = sprint(show, hp, context = :color => false)
32+
@test occursin("HybridParams", result_compact) && occursin("ParameterContainer", result_compact)
33+
34+
# Text/plain show
35+
result_full = sprint(show, MIME"text/plain"(), hp, context = :color => false)
36+
@test occursin("Hybrid Parameters", result_full) # only check for the header
37+
end
38+
39+
@testset "ParameterContainer compact show" begin
40+
params = (a = (1.0, 0.0, 2.0), b = (2.0, 1.0, 3.0))
41+
pc = ParameterContainer(params)
42+
43+
result = sprint(show, pc, context = :color => false)
44+
@test occursin("ParameterContainer(a, b)", result)
45+
end
46+
47+
@testset "SingleNNHybridModel show" begin
48+
function test_model(; x1, a, b)
49+
return (; y_pred = a .* x1 .+ b)
50+
end
51+
52+
model = constructHybridModel(
53+
[:x1, :x2], [:x3], [:y], test_model,
54+
(a = (1.0, 0.0, 2.0), b = (2.0, 1.0, 3.0)),
55+
[:a], [:b];
56+
hidden_layers = [4, 4], activation = tanh
57+
)
58+
59+
result = sprint(show, MIME"text/plain"(), model, context = :color => false)
60+
61+
@test occursin("Hybrid Model (Single NN)", result)
62+
@test occursin("Neural Network:", result) && occursin("Configuration:", result)
63+
@test all(
64+
occursin.(
65+
[
66+
"predictors", "forcing", "targets", "mechanistic_model",
67+
"neural_param_names", "global_param_names", "scale_nn_outputs",
68+
"start_from_default", "config", "Parameters:",
69+
], Ref(result)
70+
)
71+
)
72+
end
73+
74+
@testset "MultiNNHybridModel show" begin
75+
function test_model(; x1, x2, x3, a, b, c, d)
76+
return (; obs = a .* x2 .+ d .* x1 .+ b)
77+
end
78+
79+
model = constructHybridModel(
80+
(a = [:x2, :x3], d = [:x1]), [:x1], [:obs], test_model,
81+
(a = (1.0, 0.0, 5.0), b = (2.0, 0.0, 10.0), c = (0.5, 0.0, 2.0), d = (0.5, 0.0, 2.0)),
82+
[:b]; # Only global_param_names, neural_param_names derived from predictors keys
83+
hidden_layers = [4, 4], activation = tanh
84+
)
85+
86+
result = sprint(show, MIME"text/plain"(), model, context = :color => false)
87+
88+
@test occursin("Hybrid Model (Multi NN)", result)
89+
@test occursin("Neural Networks:", result) && occursin("Configuration:", result)
90+
@test all(occursin.(["predictors", "a", "d", "forcing", "targets", "config", "Parameters:"], Ref(result)))
91+
end
92+
93+
@testset "IndentedIO" begin
94+
result = sprint(
95+
io -> begin
96+
ido = IndentedIO(io; indent = " ")
97+
println(ido, "line1")
98+
println(ido, "line2")
99+
end, context = :color => false
100+
)
101+
102+
@test occursin(" line1", result) && occursin(" line2", result)
103+
104+
# Test flush, isopen, close, readavailable methods
105+
io1 = IOBuffer()
106+
ido1 = IndentedIO(io1; indent = " ")
107+
@test isopen(ido1) == isopen(io1)
108+
print(ido1, "test")
109+
flush(ido1)
110+
@test String(take!(io1)) == " test"
111+
112+
# Test readavailable (for IOBuffer)
113+
io2 = IOBuffer("data")
114+
ido2 = IndentedIO(io2)
115+
@test readavailable(ido2) == UInt8[0x64, 0x61, 0x74, 0x61]
116+
117+
# Test close
118+
io3 = IOBuffer()
119+
ido3 = IndentedIO(io3)
120+
@test isopen(ido3)
121+
close(ido3)
122+
@test !isopen(ido3)
123+
end
124+
125+
end

test/test_show_train.jl

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
using Test
2+
using EasyHybrid
3+
using EasyHybrid: TrainResults
4+
using DataFrames
5+
6+
@testset "show_train.jl" begin
7+
8+
@testset "_print_nested_keys" begin
9+
# Test scalars, arrays, nested NamedTuples, empty and non-empty tuples
10+
nt = (
11+
scalar = 1.0,
12+
array = [1, 2, 3],
13+
nested = (x = 1.0, y = 2.0),
14+
empty_tuple = (),
15+
non_empty_tuple = (1.0, 2.0, 3.0), # Non-empty tuple to cover tuple length printing
16+
)
17+
result = sprint(io -> EasyHybrid._print_nested_keys(io, nt; indent = 4), context = :color => false)
18+
19+
@test occursin("scalar", result) && occursin("array", result) && occursin("nested", result)
20+
@test occursin("(x, y)", result) # Nested property names
21+
@test occursin("(3,)", result) # Array size
22+
@test occursin("non_empty_tuple", result) && occursin("(3,)", result) # Non-empty tuple length
23+
@test occursin("empty_tuple", result) && occursin("()", result) # Empty tuple
24+
end
25+
26+
@testset "Base.show for TrainResults" begin
27+
# Comprehensive test covering all field types
28+
train_history = [(mse = (reco = 1.0, sum = 0.5), r2 = (reco = 0.9, sum = 0.45))]
29+
val_history = [(mse = (reco = 1.1, sum = 0.55), r2 = (reco = 0.88, sum = 0.44))]
30+
ps_history = [(ϕ = (), monitor = (train = (), val = ()))]
31+
train_obs_pred = DataFrame(reco = [1.0, 2.0], index = [1, 2], reco_pred = [1.1, 2.1])
32+
val_obs_pred = DataFrame(reco = [3.0], index = [3], reco_pred = [3.1])
33+
train_diffs = (Q10 = [2.0], rb = [1.0, 2.0], parameters = (rb = [1.0], Q10 = [2.0]))
34+
val_diffs = (Q10 = [2.0], rb = [3.0], parameters = (rb = [3.0], Q10 = [2.0]))
35+
ps = ([1.0, 2.0],) # Tuple, (not really the full expected type, but it's a tuple)
36+
st = (st_nn = (layer_1 = (), layer_2 = ()), fixed = ())
37+
best_epoch = 42
38+
best_loss = 0.123
39+
40+
tr = TrainResults(
41+
train_history, val_history, ps_history, train_obs_pred, val_obs_pred,
42+
train_diffs, val_diffs, ps, st, best_epoch, best_loss
43+
)
44+
45+
result = sprint(show, MIME"text/plain"(), tr; context = :color => false)
46+
47+
# All fields present
48+
@test all(
49+
occursin.(
50+
[
51+
"train_history:", "val_history:", "ps_history:", "train_obs_pred:",
52+
"val_obs_pred:", "train_diffs:", "val_diffs:", "ps:", "st:",
53+
"best_epoch:", "best_loss:",
54+
], Ref(result)
55+
)
56+
)
57+
58+
# Array sizes and nested structures
59+
@test occursin("(1,)", result)
60+
@test occursin("(reco, sum)", result) && occursin("(train, val)", result)
61+
@test occursin("(rb, Q10)", result)
62+
63+
# DataFrame format
64+
@test occursin("DataFrame", result) && occursin("reco", result) && occursin("index", result)
65+
66+
# Scalar values printed
67+
@test occursin("42", result) && occursin("0.123", result)
68+
69+
# Empty arrays and tuples handled
70+
tr_empty = TrainResults([], [], [], DataFrame(), DataFrame(), nothing, nothing, (), (), 0, 0.0)
71+
result_empty = sprint(show, MIME"text/plain"(), tr_empty; context = :color => false)
72+
@test occursin("(0,)", result_empty) && occursin("best_epoch:", result_empty)
73+
end
74+
75+
end

0 commit comments

Comments
 (0)