Skip to content

Commit 34b74df

Browse files
authored
losses tests (#212)
1 parent f5be653 commit 34b74df

File tree

7 files changed

+372
-112
lines changed

7 files changed

+372
-112
lines changed

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ include("test_generic_hybrid_model.jl")
1212
# Include SplitData tests
1313
include("test_split_data_train.jl")
1414
include("test_autodiff_backend.jl")
15+
include("test_loss_types.jl")
16+
include("test_show_loss_types.jl")
17+
include("test_compute_loss.jl")
18+
include("test_loss_fn.jl")
1519

1620
@testset "LinearHM" begin
1721
# test model instantiation

test/test_compute_loss.jl

Lines changed: 1 addition & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,110 +1,6 @@
1-
using Test
2-
using EasyHybrid
1+
using EasyHybrid: _compute_loss, PerTarget, _apply_loss, loss_fn
32
using Statistics
43
using DimensionalData
5-
using EasyHybrid: _compute_loss, PerTarget, _apply_loss, loss_fn
6-
7-
@testset "LoggingLoss" begin
8-
@testset "Constructor defaults" begin
9-
logging = LoggingLoss()
10-
@test loss_types(logging) == [:mse]
11-
@test training_loss(logging) == :mse
12-
@test extra_loss(logging) === nothing
13-
@test logging.agg == sum
14-
@test logging.train_mode == true
15-
end
16-
17-
@testset "Custom constructor" begin
18-
# Simple custom loss function
19-
custom_loss(ŷ, y) = mean(abs2, ŷ .- y)
20-
21-
# Loss function with args
22-
weighted_loss(ŷ, y, w) = w * mean(abs2, ŷ .- y)
23-
24-
# Loss function with kwargs
25-
scaled_loss(ŷ, y; scale = 1.0) = scale * mean(abs2, ŷ .- y)
26-
# extra loss
27-
extra(ŷ) = sum(abs, ŷ)
28-
29-
@testset "Basic custom constructor" begin
30-
logging = LoggingLoss(
31-
loss_types = [:mse, :mae],
32-
training_loss = :mae,
33-
agg = mean,
34-
extra_loss = extra,
35-
train_mode = false
36-
)
37-
@test loss_types(logging) == [:mse, :mae]
38-
@test training_loss(logging) == :mae
39-
@test extra_loss(logging) === extra
40-
@test logging.agg == mean
41-
@test logging.train_mode == false
42-
end
43-
44-
@testset "Mixed loss_types" begin
45-
logging = LoggingLoss(
46-
loss_types = [:mse, custom_loss, (weighted_loss, (0.5,)), (scaled_loss, (scale = 2.0,))],
47-
training_loss = :mse,
48-
agg = sum
49-
)
50-
@test length(loss_types(logging)) == 4
51-
@test loss_types(logging)[1] == :mse
52-
@test loss_types(logging)[2] == custom_loss
53-
@test loss_types(logging)[3] == (weighted_loss, (0.5,))
54-
@test loss_types(logging)[4] == (scaled_loss, (scale = 2.0,))
55-
end
56-
57-
@testset "PerTarget Mixed loss_types" begin
58-
logging = LoggingLoss(
59-
loss_types = [:mse],
60-
training_loss = (
61-
:mse,
62-
custom_loss,
63-
(weighted_loss, (0.5,)),
64-
(scaled_loss, (scale = 2.0,)),
65-
),
66-
agg = sum
67-
)
68-
69-
@test length(training_loss(logging)) == 4
70-
@test first(training_loss(logging)) == :mse
71-
@test training_loss(logging)[2] == custom_loss
72-
@test training_loss(logging)[3] == (weighted_loss, (0.5,))
73-
@test last(training_loss(logging)) == (scaled_loss, (scale = 2.0,))
74-
end
75-
76-
@testset "Custom training_loss variations" begin
77-
# Function as training_loss
78-
logging = LoggingLoss(
79-
loss_types = [:mse],
80-
training_loss = custom_loss
81-
)
82-
@test training_loss(logging) == custom_loss
83-
84-
# Tuple with args as training_loss
85-
logging = LoggingLoss(
86-
loss_types = [:mse],
87-
training_loss = (weighted_loss, (0.5,))
88-
)
89-
@test training_loss(logging) == (weighted_loss, (0.5,))
90-
91-
# Tuple with kwargs as training_loss
92-
logging = LoggingLoss(
93-
loss_types = [:mse],
94-
training_loss = (scaled_loss, (scale = 2.0,))
95-
)
96-
@test training_loss(logging) == (scaled_loss, (scale = 2.0,))
97-
98-
# Tuple with both args and kwargs
99-
complex_loss(x, y, w; scale = 1.0) = scale * w * mean(abs2, x .- y)
100-
logging = LoggingLoss(
101-
loss_types = [:mse],
102-
training_loss = (complex_loss, (0.5,), (scale = 2.0,))
103-
)
104-
@test training_loss(logging) == (complex_loss, (0.5,), (scale = 2.0,))
105-
end
106-
end
107-
end
1084

1095
@testset "_compute_loss" begin
1106
# Test data setup

test/test_generic_hybrid_model.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
using EasyHybrid
2-
using Test
31
using Lux
42
using Random
53
using AxisKeys

test/test_loss_fn.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
using Test
2-
using EasyHybrid
31
using Statistics
42

53
@testset "loss_fn methods" begin
@@ -31,7 +29,7 @@ using Statistics
3129
@test loss_fn(ŷ, y, y_nan, Val(:r2)) r^2
3230

3331
# NSE test
34-
nse = sum((ŷ .- y) .^ 2) / sum((y .- mean(y)) .^ 2)
32+
nse = 1 - sum((ŷ .- y) .^ 2) / sum((y .- mean(y)) .^ 2)
3533
@test loss_fn(ŷ, y, y_nan, Val(:nse)) nse
3634
end
3735

test/test_loss_types.jl

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
using Test
2+
using EasyHybrid
3+
using EasyHybrid: SymbolicLoss, FunctionLoss, ParameterizedLoss, ExtraLoss, _format_loss_spec, LoggingLoss, PerTarget
4+
using EasyHybrid: loss_name, loss_spec, _to_extra_loss_spec
5+
6+
identity_fn(x) = x
7+
simple_fn(x, a) = x + a
8+
simple_fn(x, y; scale = 1.0) = scale * (x + y)
9+
10+
@testset "LoggingLoss" begin
11+
@testset "Constructor defaults" begin
12+
logging = LoggingLoss()
13+
@test loss_types(logging) == [:mse]
14+
@test training_loss(logging) == :mse
15+
@test extra_loss(logging) === nothing
16+
@test logging.agg == sum
17+
@test logging.train_mode == true
18+
end
19+
20+
@testset "Custom constructor" begin
21+
# Simple custom loss function
22+
custom_loss(ŷ, y) = mean(abs2, ŷ .- y)
23+
24+
# Loss function with args
25+
weighted_loss(ŷ, y, w) = w * mean(abs2, ŷ .- y)
26+
27+
# Loss function with kwargs
28+
scaled_loss(ŷ, y; scale = 1.0) = scale * mean(abs2, ŷ .- y)
29+
# extra loss
30+
extra(ŷ) = sum(abs, ŷ)
31+
32+
@testset "Basic custom constructor" begin
33+
logging = LoggingLoss(
34+
loss_types = [:mse, :mae],
35+
training_loss = :mae,
36+
agg = mean,
37+
extra_loss = extra,
38+
train_mode = false
39+
)
40+
@test loss_types(logging) == [:mse, :mae]
41+
@test training_loss(logging) == :mae
42+
@test extra_loss(logging) === extra
43+
@test logging.agg == mean
44+
@test logging.train_mode == false
45+
end
46+
47+
@testset "Mixed loss_types" begin
48+
logging = LoggingLoss(
49+
loss_types = [:mse, custom_loss, (weighted_loss, (0.5,)), (scaled_loss, (scale = 2.0,))],
50+
training_loss = :mse,
51+
agg = sum
52+
)
53+
@test length(loss_types(logging)) == 4
54+
@test loss_types(logging)[1] == :mse
55+
@test loss_types(logging)[2] == custom_loss
56+
@test loss_types(logging)[3] == (weighted_loss, (0.5,))
57+
@test loss_types(logging)[4] == (scaled_loss, (scale = 2.0,))
58+
end
59+
60+
@testset "PerTarget Mixed loss_types" begin
61+
logging = LoggingLoss(
62+
loss_types = [:mse],
63+
training_loss = (
64+
:mse,
65+
custom_loss,
66+
(weighted_loss, (0.5,)),
67+
(scaled_loss, (scale = 2.0,)),
68+
),
69+
agg = sum
70+
)
71+
72+
@test length(training_loss(logging)) == 4
73+
@test first(training_loss(logging)) == :mse
74+
@test training_loss(logging)[2] == custom_loss
75+
@test training_loss(logging)[3] == (weighted_loss, (0.5,))
76+
@test last(training_loss(logging)) == (scaled_loss, (scale = 2.0,))
77+
end
78+
79+
@testset "Custom training_loss variations" begin
80+
# Function as training_loss
81+
logging = LoggingLoss(
82+
loss_types = [:mse],
83+
training_loss = custom_loss
84+
)
85+
@test training_loss(logging) == custom_loss
86+
87+
# Tuple with args as training_loss
88+
logging = LoggingLoss(
89+
loss_types = [:mse],
90+
training_loss = (weighted_loss, (0.5,))
91+
)
92+
@test training_loss(logging) == (weighted_loss, (0.5,))
93+
94+
# Tuple with kwargs as training_loss
95+
logging = LoggingLoss(
96+
loss_types = [:mse],
97+
training_loss = (scaled_loss, (scale = 2.0,))
98+
)
99+
@test training_loss(logging) == (scaled_loss, (scale = 2.0,))
100+
101+
# Tuple with both args and kwargs
102+
complex_loss(x, y, w; scale = 1.0) = scale * w * mean(abs2, x .- y)
103+
logging = LoggingLoss(
104+
loss_types = [:mse],
105+
training_loss = (complex_loss, (0.5,), (scale = 2.0,))
106+
)
107+
@test training_loss(logging) == (complex_loss, (0.5,), (scale = 2.0,))
108+
end
109+
end
110+
end
111+
112+
113+
@testset "ParameterizedLoss constructors" begin
114+
@testset "Basic constructor" begin
115+
pl = ParameterizedLoss(simple_fn)
116+
@test pl.f === simple_fn
117+
@test pl.args == ()
118+
@test pl.kwargs == NamedTuple()
119+
end
120+
121+
@testset "Constructor with args" begin
122+
pl = ParameterizedLoss(simple_fn, (1, 2))
123+
@test pl.f === simple_fn
124+
@test pl.args == (1, 2)
125+
@test pl.kwargs == NamedTuple()
126+
end
127+
128+
@testset "Constructor with kwargs" begin
129+
pl = ParameterizedLoss(simple_fn, (scale = 2.0,))
130+
@test pl.f === simple_fn
131+
@test pl.args == ()
132+
@test pl.kwargs == (scale = 2.0,)
133+
end
134+
end
135+
136+
@testset "_to_extra_loss_spec edge cases" begin
137+
@testset "Nothing returns ExtraLoss(nothing)" begin
138+
el = _to_extra_loss_spec(nothing)
139+
@test el isa ExtraLoss
140+
@test el.f === nothing
141+
end
142+
end
143+
144+
@testset "loss_name edge cases" begin
145+
@testset "SymbolicLoss returns name" begin
146+
@test loss_name(SymbolicLoss(:mse)) === :mse
147+
end
148+
149+
@testset "Other LossSpecs return nothing" begin
150+
@test loss_name(FunctionLoss(simple_fn)) === nothing
151+
@test loss_name(ParameterizedLoss(simple_fn)) === nothing
152+
@test loss_name(ExtraLoss(simple_fn)) === nothing
153+
end
154+
end
155+
156+
157+
@testset "loss_spec edge cases" begin
158+
@testset "SymbolicLoss" begin
159+
ls = SymbolicLoss(:mse)
160+
@test loss_spec(ls) == :mse
161+
end
162+
163+
@testset "FunctionLoss" begin
164+
fl = FunctionLoss(simple_fn)
165+
@test loss_spec(fl) === simple_fn
166+
end
167+
168+
@testset "ParameterizedLoss" begin
169+
pl = ParameterizedLoss(simple_fn, (1,), (scale = 2.0,))
170+
@test loss_spec(pl) == (simple_fn, (1,), (scale = 2.0,))
171+
end
172+
173+
@testset "ExtraLoss" begin
174+
el = ExtraLoss(simple_fn)
175+
@test loss_spec(el) === simple_fn
176+
end
177+
178+
@testset "PerTarget" begin
179+
pt = PerTarget((SymbolicLoss(:mse), SymbolicLoss(:mae)))
180+
result = loss_spec(pt)
181+
@test result isa PerTarget
182+
@test result.losses == (:mse, :mae)
183+
end
184+
end
185+
186+
@testset "PerTarget edge cases" begin
187+
@testset "Empty PerTarget" begin
188+
pt_empty = PerTarget(())
189+
@test length(pt_empty) == 0
190+
@test iterate(pt_empty) === nothing
191+
@test_throws ArgumentError first(pt_empty)
192+
@test_throws BoundsError last(pt_empty)
193+
end
194+
195+
@testset "Single-element PerTarget" begin
196+
pt_single = PerTarget((SymbolicLoss(:mse),))
197+
@test length(pt_single) == 1
198+
@test first(pt_single) == SymbolicLoss(:mse)
199+
@test last(pt_single) == SymbolicLoss(:mse)
200+
end
201+
202+
@testset "Iteration" begin
203+
pt = PerTarget((SymbolicLoss(:mse), SymbolicLoss(:mae)))
204+
vals = []
205+
for l in pt
206+
push!(vals, l)
207+
end
208+
@test vals == [SymbolicLoss(:mse), SymbolicLoss(:mae)]
209+
end
210+
end
211+
212+
@testset "PerTarget Base methods" begin
213+
@testset "Base.length" begin
214+
pt = PerTarget((SymbolicLoss(:mse), SymbolicLoss(:mae), FunctionLoss(identity_fn)))
215+
@test length(pt) == 3
216+
217+
pt_single = PerTarget((SymbolicLoss(:mse),))
218+
@test length(pt_single) == 1
219+
end
220+
221+
@testset "Base.getindex" begin
222+
pt = PerTarget((SymbolicLoss(:mse), SymbolicLoss(:mae), FunctionLoss(identity_fn)))
223+
@test pt[1] == SymbolicLoss(:mse)
224+
@test pt[2] == SymbolicLoss(:mae)
225+
@test pt[3] == FunctionLoss(identity_fn)
226+
end
227+
228+
@testset "Base.first" begin
229+
pt = PerTarget((SymbolicLoss(:mse), SymbolicLoss(:mae)))
230+
@test first(pt) == SymbolicLoss(:mse)
231+
end
232+
233+
@testset "Base.last" begin
234+
pt = PerTarget((SymbolicLoss(:mse), SymbolicLoss(:mae)))
235+
@test last(pt) == SymbolicLoss(:mae)
236+
end
237+
238+
@testset "Base.keys" begin
239+
pt = PerTarget((SymbolicLoss(:mse), SymbolicLoss(:mae)))
240+
@test keys(pt) == keys(pt.losses)
241+
@test collect(keys(pt)) == [1, 2]
242+
end
243+
244+
@testset "Base.eltype" begin
245+
pt = PerTarget((SymbolicLoss(:mse), SymbolicLoss(:mae)))
246+
@test eltype(pt) == eltype(pt.losses)
247+
end
248+
end

0 commit comments

Comments
 (0)