
Commit 0c42949

siddhantjain authored and copybara-github committed
Add a new loss for jointly training on continuous and categorical targets.
PiperOrigin-RevId: 847793593
1 parent 8d8a20b commit 0c42949

File tree: 1 file changed (+87, −12)

disentangled_rnns/library/rnn_utils.py

Lines changed: 87 additions & 12 deletions
@@ -509,6 +509,69 @@ def normalized_likelihood(
   return normlik


+def avg_nll_and_log_mse(
+    ys: np.ndarray,
+    y_hats: np.ndarray,
+    likelihood_weight: float = 0.5,
+    n_categorical_targets: int = 2,
+    n_continuous_targets: int = 1,
+) -> float:
+  """Computes a scale-normalized weighted sum of avg NLL and log(MSE).
+
+  This loss is suitable for joint training on heterogeneous objectives
+  (categorical + continuous) by putting both losses on the same log scale.
+  The log(MSE) formulation keeps the gradients from the MSE term from
+  dominating the overall gradient.
+
+  By convention, for mixed targets:
+  - `ys` contains categorical targets at index 0 and continuous targets at
+    indices [1, 1 + n_continuous_targets).
+  - `y_hats` contains logits for categorical targets at indices
+    [0, n_categorical_targets) and predictions for continuous targets at
+    indices [n_categorical_targets, n_categorical_targets +
+    n_continuous_targets).
+
+  Args:
+    ys: Ground truth targets.
+    y_hats: Network predictions (logits for categorical targets, values for
+      continuous targets).
+    likelihood_weight: The weight for the average NLL. The log(MSE) is
+      weighted by (1 - likelihood_weight). Default is 0.5 for equal weighting.
+    n_categorical_targets: The number of output logits for the categorical
+      target.
+    n_continuous_targets: The number of continuous targets.
+
+  Returns:
+    The weighted sum of average NLL and log(1 + MSE).
+  """
+  categorical_y_hats = y_hats[:, :, 0:n_categorical_targets]
+  categorical_ys = ys[:, :, 0:1]
+
+  mask = jnp.logical_not(categorical_ys < 0)
+  continuous_y_hats = y_hats[
+      :, :, n_categorical_targets : n_categorical_targets + n_continuous_targets
+  ]
+  continuous_ys = ys[:, :, 1 : 1 + n_continuous_targets]
+  continuous_ys = jnp.where(mask, continuous_ys, jnp.nan)
+
+  nll, n_unmasked_samples = categorical_neg_log_likelihood(
+      categorical_ys, categorical_y_hats
+  )
+  avg_nll = nll / n_unmasked_samples
+
+  mse_val = mse(continuous_ys, continuous_y_hats)
+
+  # This is a trick to scale the gradients from the MSE:
+  #   d/dx log(1 + mse(x)) = 1 / (1 + mse(x)) * d/dx mse(x)
+  # So early in training, when the MSE is large, the gradient is damped;
+  # later, when the MSE is small, the gradient is almost the same as the
+  # original gradient. Note that log(1 + mse) follows the same monotonic
+  # trend as mse, so this transform leads to a similar optimization as the
+  # original MSE.
+  log_mse_val = jnp.log(1.0 + mse_val)
+
+  # likelihood_weight should ideally be set to 0.5, but can be used as a
+  # toggle to train on just one objective at a time (e.g. likelihood or MSE).
+  return avg_nll * likelihood_weight + log_mse_val * (1 - likelihood_weight)
+
+
 @jax.jit
 def compute_penalty(
     targets: np.ndarray, outputs: np.ndarray
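
To make the gradient-damping comment above concrete, here is a minimal sketch (not part of the commit) of how the log(1 + MSE) transform behaves under differentiation: the gradient of log(1 + m) with respect to m is 1 / (1 + m), so a large MSE contributes a strongly damped gradient while a small MSE passes through almost unchanged.

import jax
import jax.numpy as jnp

def log1p_mse(mse_val):
  # Same transform as in avg_nll_and_log_mse: log(1 + MSE).
  return jnp.log(1.0 + mse_val)

grad_fn = jax.grad(log1p_mse)
print(grad_fn(100.0))  # ~0.0099: large MSE early in training -> damped gradient
print(grad_fn(0.1))    # ~0.91: small MSE later on -> nearly the raw gradient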
@@ -594,7 +657,8 @@ def train_network(
594657
{'penalty_scale': 0.1, 'likelihood_weight': 0.8}) or a single float for
595658
simpler losses.
596659
loss: The loss function to use. Options are 'mse', 'penalized_mse',
597-
'categorical', 'penalized_categorical', 'hybrid', 'penalized_hybrid'.
660+
'categorical', 'penalized_categorical', 'hybrid', 'penalized_hybrid',
661+
'penalized_log_hybrid'.
598662
log_losses_every: How many training steps between each time we check for
599663
errors and log the loss.
600664
do_plot: Boolean that controls whether a learning curve is plotted.
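
For orientation, a hypothetical configuration sketch follows. The keyword name `loss_param`, the elided surrounding arguments of train_network, and the exact mapping of 'penalized_log_hybrid' in the losses dict are assumptions based on this diff's docstrings, not shown in the hunks themselves.

# Hypothetical sketch: selecting the new loss option by name and passing the
# weights that penalized_hybrid_loss reads via get_loss_param. Names other
# than the dict keys and the loss string are assumptions.
loss_param = {
    'penalty_scale': 0.1,      # scales the (now averaged) penalty term
    'likelihood_weight': 0.5,  # weight on avg NLL; log(1 + MSE) gets 1 - 0.5
}
# rnn_utils.train_network(..., loss='penalized_log_hybrid', loss_param=loss_param)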
@@ -719,24 +783,35 @@ def hybrid_loss(
   def penalized_hybrid_loss(
       params, xs, ys, random_key, loss_param=loss_param_dict
   ) -> float:
-    """A hybrid loss with a penalty."""
+    """A hybrid loss with a penalty.
+
+    Useful for jointly training on categorical and continuous targets. Uses a
+    log of MSE loss for the continuous targets, so that the loss is in similar
+    units to the categorical loss.
+
+    Args:
+      params: The network parameters.
+      xs: The input data.
+      ys: The target data.
+      random_key: A JAX random key.
+      loss_param: Parameters for the loss function, potentially including
+        'penalty_scale' and 'likelihood_weight'.
+
+    Returns:
+      The computed penalized hybrid loss.
+    """

     penalty_scale = get_loss_param(loss_param, 'penalty_scale', 1.0)
     model_output = model.apply(params, random_key, xs)

-    # model_output has the continuous and categorical targets first followed
-    # by the penalty. The likelihood_and_sse function handles ignoring the
-    # penalty, hence we don't need to do anything special here.
     y_hats = model_output
-    likelihood_weight = get_loss_param(loss_param, 'likelihood_weight', 1.0)
-    supervised_loss = likelihood_and_sse(
+    likelihood_weight = get_loss_param(loss_param, 'likelihood_weight', 0.5)
+    supervised_loss = avg_nll_and_log_mse(
         ys, y_hats, likelihood_weight=likelihood_weight
     )
-    # Supervised loss here is a sum, not an average, so use the raw penalty
-    # without dividing by n_unmasked_samples.
-    # TODO(siddhantjain): Evaluate whether we should use averaging here too.
-    penalty, _ = compute_penalty(ys, y_hats)
-    loss = supervised_loss + penalty_scale * penalty
+    penalty, n_unmasked_samples = compute_penalty(ys, y_hats)
+    avg_penalty = penalty / n_unmasked_samples
+    loss = supervised_loss + penalty_scale * avg_penalty
     return loss

   losses = {
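
To clarify the target layout that avg_nll_and_log_mse (and hence the updated hybrid loss) expects, here is a toy sketch; the arrays and values are invented for illustration, and the import path simply mirrors the file touched by this commit.

import jax.numpy as jnp
from disentangled_rnns.library import rnn_utils

# Shapes are (timesteps, episodes, features). ys holds the categorical target
# at feature 0 (negative values are masked out) and the continuous target at
# feature 1; y_hats holds two categorical logits followed by the continuous
# prediction.
ys = jnp.array([
    [[0.0, 1.5]],    # class 0, continuous target 1.5
    [[1.0, -0.5]],   # class 1, continuous target -0.5
    [[-1.0, 0.0]],   # negative categorical target -> masked
])
y_hats = jnp.array([
    [[2.0, -2.0, 1.4]],
    [[-1.0, 1.0, -0.4]],
    [[0.0, 0.0, 0.0]],
])
loss = rnn_utils.avg_nll_and_log_mse(ys, y_hats, likelihood_weight=0.5)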
