"""
    compute_loss(HM, ps, st, (x, (y_t, y_nan)); logging::LoggingLoss)

Main loss function for hybrid models; handles both training and evaluation modes.

# Arguments
- `HM`: The hybrid model (an `AbstractLuxContainerLayer` or a specific model type)
- `ps`: Model parameters
- `st`: Model state
- `x`: Input data for the model
- `(y_t, y_nan)`: Tuple containing the target values and the NaN masks (functions or arrays)
- `logging`: `LoggingLoss` configuration (keyword argument)

# Returns
- In training mode (`logging.train_mode = true`):
  - `(loss_value, st, stats)`: Aggregated loss value, updated state, and an empty stats `NamedTuple`
- In evaluation mode (`logging.train_mode = false`):
  - `(loss_values, st, ŷ)`: `NamedTuple` of losses per loss type, updated state, and the predictions
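# Example

A minimal sketch, not taken from the package tests: `hm`, `x`, `y`, `y_nan`, and `logging` are hypothetical placeholders.

```julia
# `logging` is a LoggingLoss with `train_mode == true` and an `:mse` training loss (hypothetical setup):
loss, st, _ = compute_loss(hm, ps, st, (x, (y, y_nan)); logging)

# With `logging.train_mode == false`, the same call returns per-loss-type NamedTuples and the predictions:
losses, st, ŷ = compute_loss(hm, ps, st, (x, (y, y_nan)); logging)
```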
"""
function compute_loss(
        HM::LuxCore.AbstractLuxContainerLayer, ps, st, (x, (y_t, y_nan));
        logging::LoggingLoss
    )
    targets = HM.targets
    ext_loss = extra_loss(logging)
    if logging.train_mode
        # Training: a single aggregated loss over all targets
        ŷ, st = HM(x, ps, st)
        loss_value = _compute_loss(ŷ, y_t, y_nan, targets, training_loss(logging), logging.agg)
        # Add extra_loss if provided
        if ext_loss !== nothing
            extra_loss_value = ext_loss(ŷ)
            loss_value = logging.agg([loss_value, extra_loss_value...])
        end
        stats = NamedTuple()
    else
        # Evaluation: per-target losses for every requested loss type, plus the predictions
        ŷ, _ = HM(x, ps, LuxCore.testmode(st))
        loss_value = _compute_loss(ŷ, y_t, y_nan, targets, loss_types(logging), logging.agg)
        # Add extra_loss entries if provided
        if ext_loss !== nothing
            extra_loss_values = ext_loss(ŷ)
            agg_extra_loss_value = logging.agg(extra_loss_values)
            loss_value = (; loss_value..., extra_loss = (; extra_loss_values..., Symbol(logging.agg) => agg_extra_loss_value))
        end
        stats = (; ŷ...)
    end
    return loss_value, st, stats
end

function _compute_loss(ŷ, y, y_nan, targets, loss_spec, agg::Function)
    losses = assemble_loss(ŷ, y, y_nan, targets, loss_spec)
    return agg(losses)
end

function _compute_loss(ŷ, y, y_nan, targets, loss_types::Vector, agg::Function)
    out_loss_types = [
        begin
            losses = assemble_loss(ŷ, y, y_nan, targets, loss_type)
            agg_loss = agg(losses)
            NamedTuple{(targets..., Symbol(agg))}([losses..., agg_loss])
        end
        for loss_type in loss_types
    ]
    _names = [_loss_name(lt) for lt in loss_types]
    return NamedTuple{Tuple(_names)}([out_loss_types...])
end

"""
    _compute_loss(ŷ, y, y_nan, targets, loss_spec, agg::Function)
    _compute_loss(ŷ, y, y_nan, targets, loss_types::Vector, agg::Function)

Compute the loss for the given predictions and targets using the specified loss specification (or vector of loss types) and aggregation function.

# Arguments
- `ŷ`: Predicted values.
- `y`: Target values.
- `y_nan`: Mask for NaN values.
- `targets`: The targets for which the loss is computed.
- `loss_spec`: The loss specification to use during training, e.g., `:mse`.
- `loss_types::Vector`: A vector of loss types to compute, e.g., `[:mse, :mae]`.
- `agg::Function`: The aggregation function to apply to the computed losses, e.g., `sum` or `mean`.

Returns a single aggregated loss value when `loss_spec` is given, or, when `loss_types` is given, a `NamedTuple` keyed by loss name whose entries hold the per-target losses plus the aggregated value.
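# Example

A minimal sketch with hypothetical targets and made-up data; it assumes `:mse` and `:mae` are losses for which `loss_fn` methods exist:

```julia
ŷ     = (; ndvi = rand(Float32, 10), lai = rand(Float32, 10))   # predictions per target
y     = target -> rand(Float32, 10)                             # observations, looked up per target
y_nan = target -> trues(10)                                     # all entries valid
_compute_loss(ŷ, y, y_nan, (:ndvi, :lai), :mse, sum)            # single aggregated value
_compute_loss(ŷ, y, y_nan, (:ndvi, :lai), [:mse, :mae], sum)    # NamedTuple keyed by loss type
```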
"""
function _compute_loss end

function assemble_loss(ŷ, y, y_nan, targets, loss_spec)
    return [
        _apply_loss(ŷ[target], _get_target_y(y, target), _get_target_nan(y_nan, target), loss_spec)
        for target in targets
    ]
end

function assemble_loss(ŷ, y, y_nan, targets, loss_spec::PerTarget)
    @assert length(targets) == length(loss_spec.losses) "Length of targets and PerTarget losses tuple must match"
    losses = [
        _apply_loss(
            ŷ,
            _get_target_y(y, target),
            _get_target_nan(y_nan, target),
            target,
            loss_t
        ) for (target, loss_t) in zip(targets, loss_spec.losses)
    ]
    return losses
end
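
# Hypothetical usage of the `PerTarget` path (constructor form assumed, not shown in this file):
#   spec = PerTarget((:mse, :mae))                    # one loss per target, in target order
#   assemble_loss(ŷ, y, y_nan, (:ndvi, :lai), spec)   # :mse for :ndvi, :mae for :lai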

function _apply_loss(ŷ, y, y_nan, loss_spec::Symbol)
    return loss_fn(ŷ, y, y_nan, Val(loss_spec))
end

function _apply_loss(ŷ, y, y_nan, loss_spec::Function)
    return loss_fn(ŷ, y, y_nan, loss_spec)
end

function _apply_loss(ŷ, y, y_nan, loss_spec::Tuple)
    return loss_fn(ŷ, y, y_nan, loss_spec)
end

function _apply_loss(ŷ, y, y_nan, target, loss_spec)
    return _apply_loss(ŷ[target], y, y_nan, loss_spec)
end

"""
    _apply_loss(ŷ, y, y_nan, loss_spec)
    _apply_loss(ŷ, y, y_nan, target, loss_spec)

Helper function that applies the appropriate loss function based on the type of the specification.

# Arguments
- `ŷ`: Predictions for a single target (the five-argument form takes the full prediction container and indexes `ŷ[target]` itself)
- `y`: Target values for a single target
- `y_nan`: NaN mask for a single target
- `loss_spec`: Loss specification (`Symbol`, `Function`, or `Tuple`)

# Returns
- Computed loss value
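# Example

A minimal sketch with made-up arrays; it assumes a `loss_fn` method exists for `Val(:mse)`:

```julia
ŷ_t   = rand(Float32, 10)            # predictions for one target
y_t   = rand(Float32, 10)            # observations for that target
nan_t = trues(10)                    # mask of valid (non-NaN) entries
_apply_loss(ŷ_t, y_t, nan_t, :mse)   # dispatches to loss_fn(ŷ_t, y_t, nan_t, Val(:mse))
```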
"""
function _apply_loss end

_get_target_y(y, target) = y(target)
_get_target_y(y::AbstractDimArray, target) = y[col = At(target)]            # assumes the DimArray uses :col indexing
_get_target_y(y::AbstractDimArray, targets::Vector) = y[col = At(targets)]  # for multiple targets

function _get_target_y(y::Tuple, target)
    y_obs, y_sigma = y
    sigma = y_sigma isa Number ? y_sigma : y_sigma(target)
    y_obs_val = _get_target_y(y_obs, target)
    return (y_obs_val, sigma)
end

"""
    _get_target_y(y, target)

Helper function to extract target-specific values from `y`. `y` may be a callable, an `AbstractDimArray` indexed along `:col`, or a tuple `(y_obs, y_sigma)` of observations and their uncertainty.
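# Example

A small sketch with hypothetical target names, assuming `DimensionalData` is loaded and the observation array carries a `:col` dimension:

```julia
using DimensionalData
y = DimArray(rand(Float32, 10, 2), (Dim{:row}(1:10), Dim{:col}([:ndvi, :lai])))
_get_target_y(y, :ndvi)            # one column of the DimArray
_get_target_y((y, 0.1f0), :ndvi)   # (observations, fixed sigma)
```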
"""
function _get_target_y end

_get_target_nan(y_nan, target) = y_nan(target)
_get_target_nan(y_nan::AbstractDimArray, target) = y_nan[col = At(target)]            # assumes the DimArray uses :col indexing
_get_target_nan(y_nan::AbstractDimArray, targets::Vector) = y_nan[col = At(targets)]  # for multiple targets

"""
    _get_target_nan(y_nan, target)

Helper function to extract the target-specific NaN mask from `y_nan`, which may be a callable or an `AbstractDimArray` indexed along `:col`.
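# Example

A small sketch mirroring the `_get_target_y` example (hypothetical targets, `DimensionalData` assumed):

```julia
using DimensionalData
y_nan = DimArray(trues(10, 2), (Dim{:row}(1:10), Dim{:col}([:ndvi, :lai])))
_get_target_nan(y_nan, :ndvi)           # mask for one target
_get_target_nan(y_nan, [:ndvi, :lai])   # masks for several targets at once
```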
"""
function _get_target_nan end

# Helper to generate meaningful names for loss types
function _loss_name(loss_spec::Symbol)
    return loss_spec
end

function _loss_name(loss_spec::Function)
    raw_name = nameof(typeof(loss_spec))
    clean_name = Symbol(replace(string(raw_name), "#" => ""))
    return clean_name
end

function _loss_name(loss_spec::Tuple)
    return _loss_name(loss_spec[1])
end
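
# Hypothetical examples (placeholder specs, not registered losses):
#   _loss_name(:mse)           # => :mse, Symbol specs pass through unchanged
#   _loss_name(my_loss)        # => name derived from the function's type, with "#" stripped
#   _loss_name((:nse, extra))  # => :nse, Tuple specs are named by their first element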