From acff7160f0f5e43c7189978a4978010a26fc9c43 Mon Sep 17 00:00:00 2001
From: Siddhant Jain
Date: Fri, 9 Jan 2026 16:34:42 -0800
Subject: [PATCH] Add a new loss for jointly training on continuous and
 categorical targets.

PiperOrigin-RevId: 854385458
---
 disentangled_rnns/library/rnn_utils.py | 99 ++++++++++++++++++++++----
 1 file changed, 87 insertions(+), 12 deletions(-)

diff --git a/disentangled_rnns/library/rnn_utils.py b/disentangled_rnns/library/rnn_utils.py
index e0b36ad..1863685 100644
--- a/disentangled_rnns/library/rnn_utils.py
+++ b/disentangled_rnns/library/rnn_utils.py
@@ -509,6 +509,69 @@ def normalized_likelihood(
   return normlik
 
 
+def avg_nll_and_log_mse(
+    ys: np.ndarray,
+    y_hats: np.ndarray,
+    likelihood_weight: float = 0.5,
+    n_categorical_targets: int = 2,
+    n_continuous_targets: int = 1,
+) -> float:
+  """Computes a scale-normalized weighted sum of avg NLL and log(1 + MSE).
+
+  This loss is suitable for joint training on heterogeneous objectives
+  (categorical + continuous) because it puts both losses on the same log
+  scale. The log(1 + MSE) formulation keeps the MSE term's gradients from
+  dominating the overall gradient.
+
+  By convention, for mixed targets:
+  - `ys` contains categorical targets at index 0 and continuous targets at
+    indices [1, 1 + n_continuous_targets).
+  - `y_hats` contains logits for categorical targets at indices
+    [0, n_categorical_targets) and predictions for continuous targets at indices
+    [n_categorical_targets, n_categorical_targets + n_continuous_targets).
+
+  Args:
+    ys: Ground truth targets.
+    y_hats: Network predictions (logits for categorical, values for continuous).
+    likelihood_weight: The weight for the average NLL. The log(1 + MSE) is
+      weighted by (1 - likelihood_weight). Defaults to 0.5 for equal weighting.
+    n_categorical_targets: The number of output logits for the categorical
+      target.
+    n_continuous_targets: The number of continuous targets.
+
+  Returns:
+    The weighted sum of average NLL and log(1 + MSE).
+  """
+  categorical_y_hats = y_hats[:, :, 0:n_categorical_targets]
+  categorical_ys = ys[:, :, 0:1]
+
+  mask = jnp.logical_not(categorical_ys < 0)
+  continuous_y_hats = y_hats[
+      :, :, n_categorical_targets : n_categorical_targets + n_continuous_targets
+  ]
+  continuous_ys = ys[:, :, 1 : 1 + n_continuous_targets]
+  continuous_ys = jnp.where(mask, continuous_ys, jnp.nan)
+
+  nll, n_unmasked_samples = categorical_neg_log_likelihood(
+      categorical_ys, categorical_y_hats
+  )
+  avg_nll = nll / n_unmasked_samples
+
+  mse_val = mse(continuous_ys, continuous_y_hats)
+
+  # Taking log(1 + mse) rescales the gradient from the MSE term:
+  #   d[log(1 + mse)] = 1 / (1 + mse) * d[mse]
+  # Early in training, when the MSE is large, this gradient is damped; later,
+  # when the MSE is small, it is almost the same as the raw MSE gradient.
+  # Note that log(1 + mse) follows the same monotonic trend as mse, so this
+  # transform leads to a similar optimization target as the original MSE.
+  log_mse_val = jnp.log(1.0 + mse_val)
+
+  # likelihood_weight should typically be 0.5, but it can be used as a toggle
+  # to train on just one objective at a time (e.g. likelihood or mse).
+  return avg_nll * likelihood_weight + log_mse_val * (1 - likelihood_weight)
+
+
 @jax.jit
 def compute_penalty(
     targets: np.ndarray, outputs: np.ndarray
@@ -594,7 +657,8 @@ def train_network(
       {'penalty_scale': 0.1, 'likelihood_weight': 0.8}) or a single float for
       simpler losses.
     loss: The loss function to use. Options are 'mse', 'penalized_mse',
-      'categorical', 'penalized_categorical', 'hybrid', 'penalized_hybrid'.
+      'categorical', 'penalized_categorical', 'hybrid', 'penalized_hybrid',
+      'penalized_log_hybrid'.
     log_losses_every: How many training steps between each time we check for
       errors and log the loss.
     do_plot: Boolean that controls whether a learning curve is plotted.
@@ -719,24 +783,35 @@ def hybrid_loss(
 
 
   def penalized_hybrid_loss(
       params, xs, ys, random_key, loss_param=loss_param_dict
   ) -> float:
-    """A hybrid loss with a penalty."""
+    """A hybrid loss with a penalty.
+
+    Useful for jointly training on categorical and continuous targets. Uses a
+    log(1 + MSE) loss for the continuous targets, so that the loss is in
+    similar units to the categorical loss.
+
+    Args:
+      params: The network parameters.
+      xs: The input data.
+      ys: The target data.
+      random_key: A JAX random key.
+      loss_param: Parameters for the loss function, potentially including
+        'penalty_scale' and 'likelihood_weight'.
+
+    Returns:
+      The computed penalized hybrid loss.
+    """
     penalty_scale = get_loss_param(loss_param, 'penalty_scale', 1.0)
     model_output = model.apply(params, random_key, xs)
-    # model_output has the continuous and categorical targets first followed by
-    # the penalty. The likelihood_and_sse functions handles
-    # ignoring the penalty, hence we don't need to do anything special here.
     y_hats = model_output
-    likelihood_weight = get_loss_param(loss_param, 'likelihood_weight', 1.0)
-    supervised_loss = likelihood_and_sse(
+    likelihood_weight = get_loss_param(loss_param, 'likelihood_weight', 0.5)
+    supervised_loss = avg_nll_and_log_mse(
         ys, y_hats, likelihood_weight=likelihood_weight
     )
-    # Supervised loss here is a sum not an average, so use the raw penalty
-    # without dividing by n_unmasked_samples
-    # TODO(siddhantjain): Evaluate whether we should use averaging here too.
-    penalty, _ = compute_penalty(ys, y_hats)
-    loss = supervised_loss + penalty_scale * penalty
+    penalty, n_unmasked_samples = compute_penalty(ys, y_hats)
+    avg_penalty = penalty / n_unmasked_samples
+    loss = supervised_loss + penalty_scale * avg_penalty
    return loss
 
   losses = {
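
The gradient-damping argument in the comment above log_mse_val can be checked directly, since d/d(mse) log(1 + mse) = 1 / (1 + mse). The snippet below is an illustrative sanity check using only standard JAX; no project code is assumed and it is not part of the patch.

import jax
import jax.numpy as jnp

# Gradient of the log(1 + mse) transform with respect to the raw MSE value.
damping = jax.grad(lambda m: jnp.log(1.0 + m))

print(damping(9.0))  # ~0.10: large MSE early in training, gradient damped ~10x
print(damping(0.1))  # ~0.91: small MSE late in training, gradient nearly unchanged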
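A minimal usage sketch of the new loss, following the mixed-target layout described in its docstring. The array shapes and random data are illustrative assumptions, and the import path is inferred from the file's location in the patch; categorical_neg_log_likelihood and mse are existing rnn_utils helpers not shown in this diff.

import numpy as np
from disentangled_rnns.library import rnn_utils

n_sessions, n_timesteps = 2, 5

# ys layout: index 0 is the categorical target, index 1 the continuous target.
ys = np.zeros((n_sessions, n_timesteps, 2))
ys[..., 0] = np.random.randint(0, 2, size=(n_sessions, n_timesteps))
ys[..., 1] = np.random.randn(n_sessions, n_timesteps)

# y_hats layout: two categorical logits followed by one continuous prediction.
y_hats = np.random.randn(n_sessions, n_timesteps, 3)

loss = rnn_utils.avg_nll_and_log_mse(
    ys,
    y_hats,
    likelihood_weight=0.5,
    n_categorical_targets=2,
    n_continuous_targets=1,
)
print(loss)  # 0.5 * avg NLL + 0.5 * log(1 + MSE)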