@@ -509,6 +509,87 @@ def normalized_likelihood(
   return normlik
 
 
+def avg_nll_and_log_mse(
+    ys: np.ndarray,
+    y_hats: np.ndarray,
+    likelihood_weight: float = 0.5,
+    n_categorical_targets: int = 2,
+    n_continuous_targets: int = 1,
+) -> float:
+  """Computes a scale-normalized weighted sum of avg NLL and log(MSE).
+
+  This loss is suitable for joint training on heterogeneous objectives
+  (categorical + continuous) by putting both losses on the same log scale.
+  The log(MSE) formulation ensures that both components have comparable
+  magnitudes during optimization.
+
+  By convention, for mixed targets:
+    - `ys` contains categorical targets at index 0 and continuous targets at
+      indices [1, 1 + n_continuous_targets).
+    - `y_hats` contains logits for categorical targets at indices
+      [0, n_categorical_targets) and predictions for continuous targets at
+      indices [n_categorical_targets, n_categorical_targets + n_continuous_targets).
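+
+  For example, with the defaults (n_categorical_targets=2 and
+  n_continuous_targets=1), ys[..., 0] holds the class label and ys[..., 1] the
+  continuous target, while y_hats[..., 0:2] holds the logits and y_hats[..., 2]
+  the continuous prediction.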
+
+  Args:
+    ys: Ground truth targets.
+    y_hats: Network predictions (logits for categorical, values for continuous).
+    likelihood_weight: The weight for the average NLL. The log(MSE) is weighted
+      by (1 - likelihood_weight). Default is 0.5 for equal weighting.
+    n_categorical_targets: The number of output logits for the categorical
+      target.
+    n_continuous_targets: The number of continuous targets.
+
+  Returns:
+    The weighted sum of average NLL and log(1 + MSE).
+  """
+  categorical_y_hats = y_hats[:, :, 0:n_categorical_targets]
+  categorical_ys = ys[:, :, 0:1]
+
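+  # Negative categorical targets are treated as masked entries; the same mask
+  # is used below to blank out the paired continuous targets with NaN.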
+  mask = jnp.logical_not(categorical_ys < 0)
+  continuous_y_hats = y_hats[
+      :, :, n_categorical_targets : n_categorical_targets + n_continuous_targets
+  ]
+  continuous_ys = ys[:, :, 1 : 1 + n_continuous_targets]
+  continuous_ys = jnp.where(mask, continuous_ys, jnp.nan)
+
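+  # categorical_neg_log_likelihood returns the total NLL together with the
+  # number of unmasked samples, so dividing yields a per-sample average whose
+  # scale does not depend on batch size.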
+  nll, n_unmasked_samples = categorical_neg_log_likelihood(
+      categorical_ys, categorical_y_hats
+  )
+  avg_nll = nll / n_unmasked_samples
+
+  mse_val = mse(continuous_ys, continuous_y_hats)
+
+  # Taking log(1 + mse) rescales the gradients from the mse:
+  #   d log(1 + mse) = 1 / (1 + mse) * d(mse),
+  # so early in training, when mse is large, the gradient is damped; later,
+  # when mse is small, it is almost the same as the original mse gradient.
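+  # For example, at mse = 9 the gradient is scaled by 1/10, while at
+  # mse = 0.01 it is scaled by ~0.99, i.e. left nearly unchanged.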
+  log_mse_val = jnp.log(1.0 + mse_val)
+
+  # likelihood_weight defaults to 0.5 for equal weighting, but can be set to
+  # 0 or 1 to train on just one objective (mse only or likelihood only).
+  return avg_nll * likelihood_weight + log_mse_val * (1 - likelihood_weight)
+
+
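+# A minimal usage sketch for avg_nll_and_log_mse (hypothetical shapes; the
+# [:, :, ...] indexing above implies (batch, time, feature) arrays). With the
+# default target counts, ys carries 2 features and y_hats carries 3:
+#   ys = np.zeros((8, 100, 2))      # [class label, continuous target]
+#   y_hats = np.zeros((8, 100, 3))  # [logit_0, logit_1, continuous prediction]
+#   loss = avg_nll_and_log_mse(ys, y_hats, likelihood_weight=0.5)
+
+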
 @jax.jit
 def compute_penalty(
     targets: np.ndarray, outputs: np.ndarray
@@ -594,7 +655,8 @@ def train_network(
       {'penalty_scale': 0.1, 'likelihood_weight': 0.8}) or a single float for
       simpler losses.
     loss: The loss function to use. Options are 'mse', 'penalized_mse',
-      'categorical', 'penalized_categorical', 'hybrid', 'penalized_hybrid'.
+      'categorical', 'penalized_categorical', 'hybrid', 'penalized_hybrid',
+      'penalized_log_hybrid'.
     log_losses_every: How many training steps between each time we check for
       errors and log the loss.
     do_plot: Boolean that controls whether a learning curve is plotted.
@@ -719,24 +781,38 @@ def hybrid_loss(
   def penalized_hybrid_loss(
       params, xs, ys, random_key, loss_param=loss_param_dict
   ) -> float:
-    """A hybrid loss with a penalty."""
+    """A hybrid loss with a penalty.
+
+    Useful for jointly training on categorical and continuous targets. Uses a
+    log of MSE loss for the continuous targets, so that the loss is in units
+    comparable to the categorical loss.
+
+    Args:
+      params: The network parameters.
+      xs: The input data.
+      ys: The target data.
+      random_key: A JAX random key.
+      loss_param: Parameters for the loss function, potentially including
+        'penalty_scale' and 'likelihood_weight'.
+
+    Returns:
+      The computed penalized hybrid loss.
+    """
 
     penalty_scale = get_loss_param(loss_param, 'penalty_scale', 1.0)
     model_output = model.apply(params, random_key, xs)
 
-    # model_output has the continuous and categorical targets first followed by
-    # the penalty. The likelihood_and_sse functions handles
-    # ignoring the penalty, hence we don't need to do anything special here.
     y_hats = model_output
-    likelihood_weight = get_loss_param(loss_param, 'likelihood_weight', 1.0)
-    supervised_loss = likelihood_and_sse(
+    likelihood_weight = get_loss_param(loss_param, 'likelihood_weight', 0.5)
+    supervised_loss = avg_nll_and_log_mse(
         ys, y_hats, likelihood_weight=likelihood_weight
     )
-    # Supervised loss here is a sum not an average, so use the raw penalty
-    # without dividing by n_unmasked_samples
-    # TODO(siddhantjain): Evaluate whether we should use averaging here too.
-    penalty, _ = compute_penalty(ys, y_hats)
-    loss = supervised_loss + penalty_scale * penalty
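+    # The supervised loss is now a per-sample average, so the penalty is also
+    # averaged over its unmasked-sample count, keeping penalty_scale's effect
+    # independent of batch size.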
+    penalty, n_unmasked_samples = compute_penalty(ys, y_hats)
+    avg_penalty = penalty / n_unmasked_samples
+    loss = supervised_loss + penalty_scale * avg_penalty
     return loss
 
   losses = {