
Commit 49efa3d

siddhantjain authored and copybara-github committed

Add a new loss for jointly training on continuous and categorical targets.

PiperOrigin-RevId: 847793593

1 parent cb21fdf · commit 49efa3d

1 file changed: +85 −12 lines

disentangled_rnns/library/rnn_utils.py

Lines changed: 85 additions & 12 deletions
@@ -509,6 +509,67 @@ def normalized_likelihood(
   return normlik


+def avg_nll_and_log_mse(
+    ys: np.ndarray,
+    y_hats: np.ndarray,
+    likelihood_weight: float = 0.5,
+    n_categorical_targets: int = 2,
+    n_continuous_targets: int = 1,
+) -> float:
+  """Computes a scale-normalized weighted sum of avg NLL and log(MSE).
+
+  This loss is suitable for joint training on heterogeneous objectives
+  (categorical + continuous) by putting both losses on the same log scale.
+  The log(MSE) formulation ensures that both components have comparable
+  magnitudes during optimization.
+
+  By convention, for mixed targets:
+  - `ys` contains categorical targets at index 0 and continuous targets at
+    indices [1, 1 + n_continuous_targets).
+  - `y_hats` contains logits for the categorical target at indices
+    [0, n_categorical_targets) and predictions for continuous targets at
+    indices [n_categorical_targets, n_categorical_targets +
+    n_continuous_targets).
+
+  Args:
+    ys: Ground-truth targets.
+    y_hats: Network predictions (logits for the categorical target, values
+      for the continuous targets).
+    likelihood_weight: The weight for the average NLL. The log(MSE) term is
+      weighted by (1 - likelihood_weight). Defaults to 0.5 for equal
+      weighting.
+    n_categorical_targets: The number of output logits for the categorical
+      target.
+    n_continuous_targets: The number of continuous targets.
+
+  Returns:
+    The weighted sum of average NLL and log(1 + MSE).
+  """
+  categorical_y_hats = y_hats[:, :, 0:n_categorical_targets]
+  categorical_ys = ys[:, :, 0:1]
+
+  mask = jnp.logical_not(categorical_ys < 0)
+  continuous_y_hats = y_hats[
+      :, :, n_categorical_targets : n_categorical_targets + n_continuous_targets
+  ]
+  continuous_ys = ys[:, :, 1 : 1 + n_continuous_targets]
+  continuous_ys = jnp.where(mask, continuous_ys, jnp.nan)
+
+  nll, n_unmasked_samples = categorical_neg_log_likelihood(
+      categorical_ys, categorical_y_hats
+  )
+  avg_nll = nll / n_unmasked_samples
+
+  mse_val = mse(continuous_ys, continuous_y_hats)
+
+  # This is a trick to scale the gradients from the MSE:
+  # d(log(1 + mse))/dx = 1 / (1 + mse) * d(1 + mse)/dx.
+  # Early in training, when the MSE is large, the gradient is damped; later,
+  # when the MSE is small, the gradient is almost the same as the original.
+  log_mse_val = jnp.log(1.0 + mse_val)
+
+  # likelihood_weight should ideally be set to 0.5, but can be used as a
+  # toggle to train on just one objective at a time (e.g. likelihood or MSE).
+  return avg_nll * likelihood_weight + log_mse_val * (1 - likelihood_weight)
+
+
 @jax.jit
 def compute_penalty(
     targets: np.ndarray, outputs: np.ndarray
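As a sanity check on the gradient-scaling comment above, here is a minimal standalone JAX sketch (not part of the commit) showing that the derivative of log(1 + mse) is 1 / (1 + mse), so a large MSE is heavily damped while a small MSE passes through almost unchanged:

    import jax
    import jax.numpy as jnp

    # d/dm log(1 + m) = 1 / (1 + m): the gradient shrinks as m grows.
    grad_log1p_mse = jax.grad(lambda m: jnp.log(1.0 + m))

    print(grad_log1p_mse(99.0))  # ~0.01: large MSE early in training, damped
    print(grad_log1p_mse(0.01))  # ~0.99: small MSE later on, nearly undamped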
@@ -594,7 +655,8 @@ def train_network(
594655
{'penalty_scale': 0.1, 'likelihood_weight': 0.8}) or a single float for
595656
simpler losses.
596657
loss: The loss function to use. Options are 'mse', 'penalized_mse',
597-
'categorical', 'penalized_categorical', 'hybrid', 'penalized_hybrid'.
658+
'categorical', 'penalized_categorical', 'hybrid', 'penalized_hybrid',
659+
'penalized_log_hybrid'.
598660
log_losses_every: How many training steps between each time we check for
599661
errors and log the loss.
600662
do_plot: Boolean that controls whether a learning curve is plotted.
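To illustrate how the new option would be selected, a hypothetical call sketch: only `loss` and `loss_param` appear in this diff, so all other argument names below are placeholder assumptions, not the repository's actual API.

    # Hypothetical usage sketch; argument names other than `loss` and
    # `loss_param` are assumptions standing in for train_network's real API.
    params, opt_state, training_loss = rnn_utils.train_network(
        make_network,
        dataset_train,
        dataset_eval,
        loss='penalized_log_hybrid',
        loss_param={'penalty_scale': 0.1, 'likelihood_weight': 0.5},
    )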
@@ -719,24 +781,35 @@ def hybrid_loss(
   def penalized_hybrid_loss(
       params, xs, ys, random_key, loss_param=loss_param_dict
   ) -> float:
-    """A hybrid loss with a penalty."""
+    """A hybrid loss with a penalty.
+
+    Useful for jointly training on categorical and continuous targets. Uses a
+    log(MSE) loss for the continuous targets, so that the loss is in similar
+    units to the categorical loss.
+
+    Args:
+      params: The network parameters.
+      xs: The input data.
+      ys: The target data.
+      random_key: A JAX random key.
+      loss_param: Parameters for the loss function, potentially including
+        'penalty_scale' and 'likelihood_weight'.
+
+    Returns:
+      The computed penalized hybrid loss.
+    """

     penalty_scale = get_loss_param(loss_param, 'penalty_scale', 1.0)
     model_output = model.apply(params, random_key, xs)

-    # model_output has the continuous and categorical targets first, followed
-    # by the penalty. The likelihood_and_sse function handles ignoring the
-    # penalty, hence we don't need to do anything special here.
     y_hats = model_output
-    likelihood_weight = get_loss_param(loss_param, 'likelihood_weight', 1.0)
-    supervised_loss = likelihood_and_sse(
+    likelihood_weight = get_loss_param(loss_param, 'likelihood_weight', 0.5)
+    supervised_loss = avg_nll_and_log_mse(
         ys, y_hats, likelihood_weight=likelihood_weight
     )
-    # Supervised loss here is a sum, not an average, so use the raw penalty
-    # without dividing by n_unmasked_samples.
-    # TODO(siddhantjain): Evaluate whether we should use averaging here too.
-    penalty, _ = compute_penalty(ys, y_hats)
-    loss = supervised_loss + penalty_scale * penalty
+    penalty, n_unmasked_samples = compute_penalty(ys, y_hats)
+    avg_penalty = penalty / n_unmasked_samples
+    loss = supervised_loss + penalty_scale * avg_penalty
     return loss

   losses = {
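For reference, a small NumPy sketch of the mixed-target layout that avg_nll_and_log_mse expects. The shapes and the meaning of the first two axes are illustrative assumptions; only the feature-axis layout and the negative-value masking convention come from the diff above.

    import numpy as np

    # Illustrative shapes; the semantics of the first two axes are assumed.
    n_categorical_targets, n_continuous_targets = 2, 1
    n_steps, batch_size = 10, 4

    # Targets: categorical target at index 0 (values < 0 mark masked samples),
    # continuous targets at indices [1, 1 + n_continuous_targets).
    ys = np.zeros((n_steps, batch_size, 1 + n_continuous_targets))

    # Predictions: categorical logits first, then continuous predictions.
    y_hats = np.zeros(
        (n_steps, batch_size, n_categorical_targets + n_continuous_targets)
    )

    categorical_logits = y_hats[:, :, :n_categorical_targets]
    continuous_preds = y_hats[:, :, n_categorical_targets:]
    categorical_targets = ys[:, :, 0:1]
    continuous_targets = ys[:, :, 1:1 + n_continuous_targets]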
