@@ -509,6 +509,69 @@ def normalized_likelihood(
   return normlik
 
 
+def avg_nll_and_log_mse(
+    ys: np.ndarray,
+    y_hats: np.ndarray,
+    likelihood_weight: float = 0.5,
+    n_categorical_targets: int = 2,
+    n_continuous_targets: int = 1,
+) -> float:
+  """Computes a scale-normalized weighted sum of avg NLL and log(MSE).
+
+  This loss is suitable for joint training on heterogeneous objectives
+  (categorical + continuous) by putting both losses on the same log scale.
+  The log(MSE) formulation keeps the gradients from the MSE term from
+  dominating the overall gradient.
+
+  By convention, for mixed targets:
+    - `ys` contains categorical targets at index 0 and continuous targets at
+      indices [1, 1 + n_continuous_targets).
+    - `y_hats` contains logits for categorical targets at indices
+      [0, n_categorical_targets) and predictions for continuous targets at
+      indices [n_categorical_targets,
+      n_categorical_targets + n_continuous_targets).
+
+  Args:
+    ys: Ground truth targets.
+    y_hats: Network predictions (logits for categorical, values for
+      continuous).
+    likelihood_weight: The weight for the average NLL. The log(MSE) is
+      weighted by (1 - likelihood_weight). Default is 0.5 for equal weighting.
+    n_categorical_targets: The number of output logits for the categorical
+      target.
+    n_continuous_targets: The number of continuous targets.
+
+  Returns:
+    The weighted sum of the average NLL and log(1 + MSE).
+  """
+  categorical_y_hats = y_hats[:, :, 0:n_categorical_targets]
+  categorical_ys = ys[:, :, 0:1]
+
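+  # A negative categorical target marks a masked sample; its continuous
+  # target is set to NaN below so the mse helper can exclude it.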
+  mask = jnp.logical_not(categorical_ys < 0)
+  continuous_y_hats = y_hats[
+      :, :, n_categorical_targets:n_categorical_targets + n_continuous_targets
+  ]
+  continuous_ys = ys[:, :, 1:1 + n_continuous_targets]
+  continuous_ys = jnp.where(mask, continuous_ys, jnp.nan)
+
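+  # Average the NLL over unmasked samples so its scale does not depend on the
+  # batch size or on how many samples are masked.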
+  nll, n_unmasked_samples = categorical_neg_log_likelihood(
+      categorical_ys, categorical_y_hats
+  )
+  avg_nll = nll / n_unmasked_samples
+
+  mse_val = mse(continuous_ys, continuous_y_hats)
+
+  # This is a trick to scale the gradients from the MSE:
+  #   d/dx log(1 + mse) = 1 / (1 + mse) * d/dx (1 + mse)
+  # Early in training, when the MSE is large, the gradient is damped; later,
+  # when the MSE is small, the gradient is almost the same as the original
+  # gradient. Note that log(1 + mse) follows the same monotonic trend as mse,
+  # so this transform leads to a similar optimization as the original MSE.
+  log_mse_val = jnp.log(1.0 + mse_val)
+
+  # Likelihood weight should ideally be set to 0.5, but can be used as a
+  # toggle to train on just one objective at a time (e.g. likelihood or mse).
+  return avg_nll * likelihood_weight + log_mse_val * (1 - likelihood_weight)
+
+
 @jax.jit
 def compute_penalty(
     targets: np.ndarray, outputs: np.ndarray
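Note: as a quick illustration of the target layout described in the docstring above, a minimal call might look like the following sketch. The toy shapes and values are assumptions; `avg_nll_and_log_mse` relies on this module's `categorical_neg_log_likelihood` and `mse` helpers.

import numpy as np

# Toy batch of shape (batch, time, targets). Per the docstring convention,
# ys index 0 holds the categorical label and index 1 the continuous target;
# a negative label (-1) marks a masked time step.
ys = np.array([[[1.0, 0.3], [0.0, -0.2], [-1.0, 0.0]]])
# y_hats holds two categorical logits followed by one continuous prediction.
y_hats = np.array([[[0.2, 1.1, 0.25], [1.3, -0.4, -0.1], [0.0, 0.0, 0.0]]])

loss = avg_nll_and_log_mse(ys, y_hats, likelihood_weight=0.5)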
@@ -594,7 +657,8 @@ def train_network(
       {'penalty_scale': 0.1, 'likelihood_weight': 0.8}) or a single float for
       simpler losses.
     loss: The loss function to use. Options are 'mse', 'penalized_mse',
-      'categorical', 'penalized_categorical', 'hybrid', 'penalized_hybrid'.
+      'categorical', 'penalized_categorical', 'hybrid', 'penalized_hybrid',
+      'penalized_log_hybrid'.
     log_losses_every: How many training steps between each time we check for
       errors and log the loss.
     do_plot: Boolean that controls whether a learning curve is plotted.
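For reference, selecting the new option might look like the sketch below; the other train_network arguments are elided, and the loss_param keys follow the docstring above.

train_network(
    ...,  # data, model, and optimizer arguments unchanged
    loss='penalized_log_hybrid',
    loss_param={'penalty_scale': 0.1, 'likelihood_weight': 0.5},
    log_losses_every=100,
    do_plot=False,
)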
@@ -719,24 +783,35 @@ def hybrid_loss(
   def penalized_hybrid_loss(
       params, xs, ys, random_key, loss_param=loss_param_dict
   ) -> float:
-    """A hybrid loss with a penalty."""
+    """A hybrid loss with a penalty.
+
+    Useful for jointly training on categorical and continuous targets. Uses a
+    log-of-MSE loss for the continuous targets, so that the loss is in
+    similar units to the categorical loss.
+
+    Args:
+      params: The network parameters.
+      xs: The input data.
+      ys: The target data.
+      random_key: A JAX random key.
+      loss_param: Parameters for the loss function, potentially including
+        'penalty_scale' and 'likelihood_weight'.
+
+    Returns:
+      The computed penalized hybrid loss.
+    """
 
     penalty_scale = get_loss_param(loss_param, 'penalty_scale', 1.0)
     model_output = model.apply(params, random_key, xs)
 
-    # model_output has the continuous and categorical targets first followed by
-    # the penalty. The likelihood_and_sse functions handles
-    # ignoring the penalty, hence we don't need to do anything special here.
     y_hats = model_output
-    likelihood_weight = get_loss_param(loss_param, 'likelihood_weight', 1.0)
-    supervised_loss = likelihood_and_sse(
+    likelihood_weight = get_loss_param(loss_param, 'likelihood_weight', 0.5)
+    supervised_loss = avg_nll_and_log_mse(
         ys, y_hats, likelihood_weight=likelihood_weight
     )
-    # Supervised loss here is a sum not an average, so use the raw penalty
-    # without dividing by n_unmasked_samples
-    # TODO(siddhantjain): Evaluate whether we should use averaging here too.
-    penalty, _ = compute_penalty(ys, y_hats)
-    loss = supervised_loss + penalty_scale * penalty
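+    # The supervised loss is now an average, so average the penalty over the
+    # unmasked samples as well to keep the two terms on comparable scales.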
+    penalty, n_unmasked_samples = compute_penalty(ys, y_hats)
+    avg_penalty = penalty / n_unmasked_samples
+    loss = supervised_loss + penalty_scale * avg_penalty
     return loss
 
   losses = {