
Commit f5be653

per target (#210)

* per target
* rename file
* rename lossfn to compute_loss
* fix name include
* r f
* orga split
* tests PerTarget

1 parent 261c876 commit f5be653

File tree

12 files changed: +424, -329 lines changed


docs/src/tutorials/losses.md

Lines changed: 16 additions & 16 deletions
@@ -2,11 +2,11 @@
 
 ```@example loss
 using EasyHybrid
-using EasyHybrid: compute_loss
+using EasyHybrid: _compute_loss
 ```
 
 ````@docs; canonical=false
-EasyHybrid.compute_loss
+EasyHybrid._compute_loss
 ````
 
 ::: warning
@@ -21,12 +21,12 @@ EasyHybrid.compute_loss
 - Prefer `f(ŷ_masked, y_masked)` for custom losses; `y_masked` may be a vector or `(y, σ)`.
 - Use `Val(:metric)` only for predefined `loss_fn` variants.
 - Quick calls:
-  - `compute_loss(..., :mse, sum)`: predefined
-  - `compute_loss(..., custom_loss, sum)`: custom loss
-  - `compute_loss(..., (f, (arg1, arg2, )), sum)`: additional arguments
-  - `compute_loss(..., (f, (kw=val,)), sum)`: with keyword arguments
-  - `compute_loss(..., (f, (arg1, ), (kw=val,)), sum)`: with additional arguments and keyword arguments
-  - `compute_loss(..., (y, y_sigma), ..., custom_loss_uncertainty, sum)`: with uncertainties
+  - `_compute_loss(..., :mse, sum)`: predefined
+  - `_compute_loss(..., custom_loss, sum)`: custom loss
+  - `_compute_loss(..., (f, (arg1, arg2, )), sum)`: additional arguments
+  - `_compute_loss(..., (f, (kw=val,)), sum)`: with keyword arguments
+  - `_compute_loss(..., (f, (arg1, ), (kw=val,)), sum)`: with additional arguments and keyword arguments
+  - `_compute_loss(..., (y, y_sigma), ..., custom_loss_uncertainty, sum)`: with uncertainties
 
 :::
@@ -44,8 +44,8 @@ targets = [:t1, :t2]
 ```
 
 ```@ansi loss
-mse_total = compute_loss(ŷ, y, y_nan, targets, :mse, sum) # total MSE across targets
-losses = compute_loss(ŷ, y, y_nan, targets, [:mse, :mae], sum) # multiple metrics in a NamedTuple
+mse_total = _compute_loss(ŷ, y, y_nan, targets, :mse, sum) # total MSE across targets
+losses = _compute_loss(ŷ, y, y_nan, targets, [:mse, :mae], sum) # multiple metrics in a NamedTuple
 ```
 
 ### Custom functions, args, kwargs
@@ -63,10 +63,10 @@ nothing # hide
 Use variants:
 
 ```@ansi loss
-compute_loss(ŷ, y, y_nan, targets, custom_loss, sum)
-compute_loss(ŷ, y, y_nan, targets, (weighted_loss, (0.5,)), sum)
-compute_loss(ŷ, y, y_nan, targets, (scaled_loss, (scale=2.0,)), sum)
-compute_loss(ŷ, y, y_nan, targets, (complex_loss, (0.5,), (scale=2.0,)), sum)
+_compute_loss(ŷ, y, y_nan, targets, custom_loss, sum)
+_compute_loss(ŷ, y, y_nan, targets, (weighted_loss, (0.5,)), sum)
+_compute_loss(ŷ, y, y_nan, targets, (scaled_loss, (scale=2.0,)), sum)
+_compute_loss(ŷ, y, y_nan, targets, (complex_loss, (0.5,), (scale=2.0,)), sum)
 ```
 
 ### Uncertainty-aware losses
@@ -90,13 +90,13 @@ Top-level usage (both `y` and `y_sigma` can be functions or containers):
 
 ```julia
 y_sigma(t) = t == :t1 ? [0.1, 0.2] : [0.2, 0.1]
-loss = compute_loss(ŷ, (y, y_sigma), y_nan, targets,
+loss = _compute_loss(ŷ, (y, y_sigma), y_nan, targets,
     custom_loss_uncertainty, sum)
 ```
 
 ::: info Behavior
 
-- `compute_loss` packs per-target `(y_vals_target, σ_target)` tuples and forwards them to `loss_fn`.
+- `_compute_loss` packs per-target `(y_vals_target, σ_target)` tuples and forwards them to `loss_fn`.
 - Predefined metrics use only `y_vals` when a `(y, σ)` tuple is supplied. (TODO)
 
 :::
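For readers skimming the diff, here is a minimal runnable sketch of the renamed `_compute_loss` call variants the tutorial lists above. The data layout mirrors the tutorial's toy setup; `custom_loss`, `weighted_loss`, and `scaled_loss` are illustrative definitions assumed here, not code from this commit:

```julia
using EasyHybrid
using EasyHybrid: _compute_loss
using Statistics: mean

# Toy inputs in the shapes the tutorial uses: predictions as a NamedTuple,
# observations and NaN masks as functions of the target symbol.
ŷ = (t1 = [1.0, 2.0], t2 = [0.5, 1.5])
y(t) = t == :t1 ? [1.1, 1.9] : [0.4, 1.6]
y_nan(t) = [true, true]            # all values valid in this sketch
targets = [:t1, :t2]

# Custom losses receive masked data, per the tutorial: f(ŷ_masked, y_masked).
custom_loss(ŷ_m, y_m) = mean(abs2, ŷ_m .- y_m)
weighted_loss(ŷ_m, y_m, w) = w * mean(abs2, ŷ_m .- y_m)
scaled_loss(ŷ_m, y_m; scale = 1.0) = scale * mean(abs, ŷ_m .- y_m)

_compute_loss(ŷ, y, y_nan, targets, :mse, sum)                          # predefined
_compute_loss(ŷ, y, y_nan, targets, custom_loss, sum)                   # custom loss
_compute_loss(ŷ, y, y_nan, targets, (weighted_loss, (0.5,)), sum)       # extra args
_compute_loss(ŷ, y, y_nan, targets, (scaled_loss, (scale = 2.0,)), sum) # kwargs
```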

projects/RbQ10/Q10.jl

Lines changed: 3 additions & 3 deletions
@@ -87,14 +87,14 @@ out_pinball = train(
 );
 
 ## legacy
-# ? test lossfn
+# ? test compute_loss
 ps, st = LuxCore.setup(Random.default_rng(), RbQ10)
 # the Tuple `ds_p, ds_t` is later used for batching in the `dataloader`.
 ds_p_f, ds_t = EasyHybrid.prepare_data(RbQ10, ds_keyed)
 ds_t_nan = .!isnan.(ds_t)
-ls = EasyHybrid.lossfn(RbQ10, ds_p_f, (ds_t, ds_t_nan), ps, st, LoggingLoss())
+ls = EasyHybrid.compute_loss(RbQ10, ds_p_f, (ds_t, ds_t_nan), ps, st, LoggingLoss())
 
-ls_logs = EasyHybrid.lossfn(RbQ10, ds_p_f, (ds_t, ds_t_nan), ps, st, LoggingLoss(train_mode = false))
+ls_logs = EasyHybrid.compute_loss(RbQ10, ds_p_f, (ds_t, ds_t_nan), ps, st, LoggingLoss(train_mode = false))
 
 # ? play with :Temp as predictors in NN, temperature sensitivity!
 # TODO: variance effect due to LSTM vs NN

projects/RbQ10/Q10_dd.jl

Lines changed: 7 additions & 7 deletions
@@ -30,7 +30,7 @@ using Zygote
 ps, st = LuxCore.setup(Random.default_rng(), RbQ10)
 
 l, backtrace = Zygote.pullback(
-    (ps) -> EasyHybrid.lossfn(
+    (ps) -> EasyHybrid.compute_loss(
         RbQ10, ds_p_f, (ds_t, ds_t_nan), ps, st,
         EasyHybrid.LoggingLoss(training_loss = :mse, agg = sum)
     ), ps
@@ -60,24 +60,24 @@ targets = RbQ10.targets
 # EasyHybrid.get_predictions_targets(RbQ10, ds_p_f, (ds_t, ds_t_nan), ps, st, targets)
 # ŷ, st_ = RbQ10(ds_p_f, ps, st)
 
-# EasyHybrid.compute_loss(ŷ, ds_t, ds_t_nan, targets, :mse, sum)
+# EasyHybrid._compute_loss(ŷ, ds_t, ds_t_nan, targets, :mse, sum)
 
-# ls = EasyHybrid.lossfn(RbQ10, ds_p_f, (ds_t, ds_t_nan), ps, st, LoggingLoss())
+# ls = EasyHybrid.compute_loss(RbQ10, ds_p_f, (ds_t, ds_t_nan), ps, st, LoggingLoss())
 
 
 ## ! DimensionalData + ChainRulesCore
-# ? test lossfn
+# ? test compute_loss
 # ps, st = LuxCore.setup(Random.default_rng(), RbQ10)
 
-ls = EasyHybrid.lossfn(RbQ10, ds_p_f, (ds_t, ds_t_nan), ps, st, LoggingLoss())
-ls_logs = EasyHybrid.lossfn(RbQ10, ds_p_f, (ds_t, ds_t_nan), ps, st, LoggingLoss(train_mode = false))
+ls = EasyHybrid.compute_loss(RbQ10, ds_p_f, (ds_t, ds_t_nan), ps, st, LoggingLoss())
+ls_logs = EasyHybrid.compute_loss(RbQ10, ds_p_f, (ds_t, ds_t_nan), ps, st, LoggingLoss(train_mode = false))
 acc_ = EasyHybrid.evaluate_acc(RbQ10, ds_p_f, ds_t, ds_t_nan, ps, st, [:mse, :r2], :mse, sum)
 
 using Zygote, ChainRulesCore, DimensionalData
 using EasyHybrid
 
 l, backtrace = Zygote.pullback(
-    (ps) -> EasyHybrid.lossfn(
+    (ps) -> EasyHybrid.compute_loss(
         RbQ10, ds_p_f, (ds_t, ds_t_nan), ps, st,
         EasyHybrid.LoggingLoss(training_loss = :mse, agg = sum)
     ), ps
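A small aside on the `Zygote.pullback` call above: the second return value, named `backtrace` here, is the pullback closure rather than a stack trace. A hedged sketch of recovering gradients with respect to `ps`, assuming this `compute_loss` method returns the same `(loss, st, stats)` tuple as the method defined in `src/utils/compute_loss.jl`:

```julia
# Seed the pullback with a cotangent matching the (loss, st, stats) output:
# one(...) for the scalar loss, nothing for the undifferentiated state/stats.
grads = backtrace((one(l[1]), nothing, nothing))[1]
```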

projects/RbQ10/Q10_lbfgs.jl

Lines changed: 2 additions & 2 deletions
@@ -47,9 +47,9 @@ end
 # the Tuple `ds_p, ds_t` is later used for batching in the `dataloader`.
 ds_p_f, ds_t = EasyHybrid.prepare_data(RbQ10, ds_keyed)
 ds_t_nan = .!isnan.(ds_t)
-ls = EasyHybrid.lossfn(RbQ10, ds_p_f, (ds_t, ds_t_nan), ps, st, LoggingLoss(train_mode = false))
+ls = EasyHybrid.compute_loss(RbQ10, ds_p_f, (ds_t, ds_t_nan), ps, st, LoggingLoss(train_mode = false))
 
-ls2 = (p, data) -> EasyHybrid.lossfn(RbQ10, ds_p_f, (ds_t, ds_t_nan), p, st, LoggingLoss())[1]
+ls2 = (p, data) -> EasyHybrid.compute_loss(RbQ10, ds_p_f, (ds_t, ds_t_nan), p, st, LoggingLoss())[1]
 
 dta = (ds_p_f, ds_t, ds_t_nan)

projects/RbQ10/synthetic_example_bookchapter.jl

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ dta = (ds_p_f, ds_t, ds_t_nan)
 dataloader = DataLoader((x_train, y_train, nan_train), batchsize = 512, shuffle = true);
 
 # wrap loss function to get arguments as required by Optimization.jl
-ls2 = (p, data) -> EasyHybrid.lossfn(RbQ10, data[1], (data[2], data[3]), p, st, LoggingLoss())[1]
+ls2 = (p, data) -> EasyHybrid.compute_loss(RbQ10, data[1], (data[2], data[3]), p, st, LoggingLoss())[1]
 
 # convert to Float64 for optimization
 ps_ca = ComponentArray(ps) .|> Float64
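The `(p, data)` signature of `ls2` matches the `OptimizationFunction` convention in Optimization.jl. A hedged sketch of the remaining wiring, which is not part of this commit (the AD backend, solver, and iteration count are assumptions):

```julia
using Optimization, OptimizationOptimisers
import Optimisers

# Differentiate ls2 with Zygote and fit ps_ca on the training arrays.
optf = OptimizationFunction(ls2, Optimization.AutoZygote())
prob = OptimizationProblem(optf, ps_ca, (x_train, y_train, nan_train))
sol = solve(prob, Optimisers.Adam(0.01); maxiters = 100)
```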

src/EasyHybrid.jl

Lines changed: 3 additions & 2 deletions
@@ -50,8 +50,9 @@ include("utils/tools.jl")
 include("models/models.jl")
 include("utils/show_generic.jl")
 include("utils/synthetic_test_data.jl")
-include("utils/logging_loss.jl")
-include("utils/show_logging.jl")
+include("utils/compute_loss_types.jl")
+include("utils/show_loss_types.jl")
+include("utils/compute_loss.jl")
 include("utils/loss_fn.jl")
 include("plotrecipes.jl")
 include("train.jl")

src/train.jl

Lines changed: 2 additions & 2 deletions
@@ -200,7 +200,7 @@ function train(
     @info "Check the saved output (.png, .mp4, .jld2) from training at: $(tmp_folder)"
 
     prog = Progress(nepochs, desc = "Training loss", enabled = show_progress)
-    loss(hybridModel, ps, st, (x, y)) = lossfn(
+    loss(hybridModel, ps, st, (x, y)) = compute_loss(
         hybridModel, ps, st, (x, y);
         logging = LoggingLoss(train_mode = true, loss_types = loss_types, training_loss = training_loss, extra_loss = extra_loss, agg = agg)
     )
@@ -366,7 +366,7 @@
 end
 
 function evaluate_acc(ghm, x, y, y_no_nan, ps, st, loss_types, training_loss, extra_loss, agg)
-    loss_val, sts, ŷ = lossfn(ghm, ps, st, (x, (y, y_no_nan)), logging = LoggingLoss(train_mode = false, loss_types = loss_types, training_loss = training_loss, extra_loss = extra_loss, agg = agg))
+    loss_val, sts, ŷ = compute_loss(ghm, ps, st, (x, (y, y_no_nan)), logging = LoggingLoss(train_mode = false, loss_types = loss_types, training_loss = training_loss, extra_loss = extra_loss, agg = agg))
     return loss_val, sts, ŷ
 end
 function maybe_record_history(block, should_record, fig, output_path; framerate = 24)
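These two call sites exercise both modes of the renamed `compute_loss`. A hedged sketch of what each returns, using the keyword-argument method added in `src/utils/compute_loss.jl` (`HM`, `x`, `y_t`, and `y_nan` are placeholders for a set-up hybrid model and its data):

```julia
# Train mode: a single aggregated training loss; stats is an empty NamedTuple.
loss_value, st, _ = EasyHybrid.compute_loss(HM, ps, st, (x, (y_t, y_nan));
    logging = EasyHybrid.LoggingLoss(train_mode = true, training_loss = :mse, agg = sum))

# Eval mode: a NamedTuple of per-metric, per-target losses plus predictions.
losses, st, ŷ = EasyHybrid.compute_loss(HM, ps, st, (x, (y_t, y_nan));
    logging = EasyHybrid.LoggingLoss(train_mode = false, loss_types = [:mse, :mae], agg = sum))
losses.mse.sum   # aggregated MSE, given agg = sum
losses.mae.t1    # per-target MAE, given :t1 among the model's targets
```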

src/utils/compute_loss.jl

Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
"""
    compute_loss(HM, x, (y_t, y_nan), ps, st, logging::LoggingLoss)

Main loss function for hybrid models that handles both training and evaluation modes.

# Arguments
- `HM`: The hybrid model (AbstractLuxContainerLayer or specific model type)
- `x`: Input data for the model
- `(y_t, y_nan)`: Tuple containing target values and NaN mask functions/arrays
- `ps`: Model parameters
- `st`: Model state
- `logging`: LoggingLoss configuration

# Returns
- In training mode (`logging.train_mode = true`):
  - `(loss_value, st)`: Single loss value and updated state
- In evaluation mode (`logging.train_mode = false`):
  - `(loss_values, st, ŷ)`: NamedTuple of losses, state and predictions
"""
function compute_loss(
        HM::LuxCore.AbstractLuxContainerLayer, ps, st, (x, (y_t, y_nan));
        logging::LoggingLoss
    )
    targets = HM.targets
    ext_loss = extra_loss(logging)
    if logging.train_mode
        ŷ, st = HM(x, ps, st)
        loss_value = _compute_loss(ŷ, y_t, y_nan, targets, training_loss(logging), logging.agg)
        # Add extra_loss if provided
        if ext_loss !== nothing
            extra_loss_value = ext_loss(ŷ)
            loss_value = logging.agg([loss_value, extra_loss_value...])
        end
        stats = NamedTuple()
    else
        ŷ, _ = HM(x, ps, LuxCore.testmode(st))
        loss_value = _compute_loss(ŷ, y_t, y_nan, targets, loss_types(logging), logging.agg)
        # Add extra_loss entries if provided
        if ext_loss !== nothing
            extra_loss_values = ext_loss(ŷ)
            agg_extra_loss_value = logging.agg(extra_loss_values)
            loss_value = (; loss_value..., extra_loss = (; extra_loss_values..., Symbol(logging.agg) => agg_extra_loss_value))
        end
        stats = (; ŷ...)
    end
    return loss_value, st, stats
end

function _compute_loss(ŷ, y, y_nan, targets, loss_spec, agg::Function)
    losses = assemble_loss(ŷ, y, y_nan, targets, loss_spec)
    return agg(losses)
end

function _compute_loss(ŷ, y, y_nan, targets, loss_types::Vector, agg::Function)
    out_loss_types = [
        begin
            losses = assemble_loss(ŷ, y, y_nan, targets, loss_type)
            agg_loss = agg(losses)
            NamedTuple{(targets..., Symbol(agg))}([losses..., agg_loss])
        end
        for loss_type in loss_types
    ]
    _names = [_loss_name(lt) for lt in loss_types]
    return NamedTuple{Tuple(_names)}([out_loss_types...])
end

"""
    _compute_loss(ŷ, y, y_nan, targets, loss_spec, agg::Function)
    _compute_loss(ŷ, y, y_nan, targets, loss_types::Vector, agg::Function)

Compute the loss for the given predictions and targets using the specified loss type (or vector of loss types) and aggregation function.

# Arguments:
- `ŷ`: Predicted values.
- `y`: Target values.
- `y_nan`: Mask for NaN values.
- `targets`: The targets for which the loss is computed.
- `loss_spec`: The loss type to use during training, e.g., `:mse`.
- `loss_types::Vector`: A vector of loss types to compute, e.g., `[:mse, :mae]`.
- `agg::Function`: The aggregation function to apply to the computed losses, e.g., `sum` or `mean`.

Returns a single loss value if `loss_spec` is provided, or a NamedTuple of losses for each type in `loss_types`.
"""
function _compute_loss end

function assemble_loss(ŷ, y, y_nan, targets, loss_spec)
    return [
        _apply_loss(ŷ[target], _get_target_y(y, target), _get_target_nan(y_nan, target), loss_spec)
        for target in targets
    ]
end

function assemble_loss(ŷ, y, y_nan, targets, loss_spec::PerTarget)
    @assert length(targets) == length(loss_spec.losses) "Length of targets and PerTarget losses tuple must match"
    losses = [
        _apply_loss(
            ŷ,
            _get_target_y(y, target),
            _get_target_nan(y_nan, target),
            target,
            loss_t
        ) for (target, loss_t) in zip(targets, loss_spec.losses)
    ]
    return losses
end

function _apply_loss(ŷ, y, y_nan, loss_spec::Symbol)
    return loss_fn(ŷ, y, y_nan, Val(loss_spec))
end

function _apply_loss(ŷ, y, y_nan, loss_spec::Function)
    return loss_fn(ŷ, y, y_nan, loss_spec)
end

function _apply_loss(ŷ, y, y_nan, loss_spec::Tuple)
    return loss_fn(ŷ, y, y_nan, loss_spec)
end

function _apply_loss(ŷ, y, y_nan, target, loss_spec)
    return _apply_loss(ŷ[target], y, y_nan, loss_spec)
end

"""
    _apply_loss(ŷ, y, y_nan, loss_spec)

Helper function to apply the appropriate loss function based on the specification type.

# Arguments
- `ŷ`: Predictions for a single target
- `y`: Target values for a single target
- `y_nan`: NaN mask for a single target
- `loss_spec`: Loss specification (Symbol, Function, or Tuple)

# Returns
- Computed loss value
"""
function _apply_loss end

_get_target_y(y, target) = y(target)
_get_target_y(y::AbstractDimArray, target) = y[col = At(target)] # assumes the DimArray uses :col indexing
_get_target_y(y::AbstractDimArray, targets::Vector) = y[col = At(targets)] # for multiple targets

function _get_target_y(y::Tuple, target)
    y_obs, y_sigma = y
    sigma = y_sigma isa Number ? y_sigma : y_sigma(target)
    y_obs_val = _get_target_y(y_obs, target)
    return (y_obs_val, sigma)
end

"""
    _get_target_y(y, target)

Helper function to extract target-specific values from `y`, handling cases where `y` can be a tuple of `(y_obs, y_sigma)`.
"""
function _get_target_y end

_get_target_nan(y_nan, target) = y_nan(target)
_get_target_nan(y_nan::AbstractDimArray, target) = y_nan[col = At(target)] # assumes the DimArray uses :col indexing
_get_target_nan(y_nan::AbstractDimArray, targets::Vector) = y_nan[col = At(targets)] # for multiple targets

"""
    _get_target_nan(y_nan, target)

Helper function to extract target-specific values from `y_nan`.
"""
function _get_target_nan end

# Helper to generate meaningful names for loss types
function _loss_name(loss_spec::Symbol)
    return loss_spec
end

function _loss_name(loss_spec::Function)
    raw_name = nameof(typeof(loss_spec))
    clean_name = Symbol(replace(string(raw_name), "#" => ""))
    return clean_name
end

function _loss_name(loss_spec::Tuple)
    return _loss_name(loss_spec[1])
end
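The headline change in this commit is the `PerTarget` branch of `assemble_loss`, which pairs each target with its own loss spec. A hedged usage sketch, assuming `PerTarget` (defined in `utils/compute_loss_types.jl`, not shown in this diff) wraps a tuple of loss specs in a `losses` field and is reachable from `EasyHybrid`:

```julia
using EasyHybrid
using EasyHybrid: _compute_loss, PerTarget
using Statistics: mean

ŷ = (t1 = [1.0, 2.0], t2 = [0.5, 1.5])
y(t) = t == :t1 ? [1.1, 1.9] : [0.4, 1.6]
y_nan(t) = [true, true]
targets = [:t1, :t2]

# Illustrative custom loss for the second target (assumed f(ŷ_masked, y_masked) signature).
mae_like(ŷ_m, y_m) = mean(abs, ŷ_m .- y_m)

# One spec per target, in target order: :mse for :t1, mae_like for :t2.
spec = PerTarget((:mse, mae_like))   # hypothetical constructor
loss = _compute_loss(ŷ, y, y_nan, targets, spec, sum)
```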
