99 changes: 87 additions & 12 deletions disentangled_rnns/library/rnn_utils.py
@@ -509,6 +509,69 @@ def normalized_likelihood(
return normlik


def avg_nll_and_log_mse(
ys: np.ndarray,
y_hats: np.ndarray,
likelihood_weight: float = 0.5,
n_categorical_targets: int = 2,
n_continuous_targets: int = 1,
) -> float:
"""Computes a scale-normalized weighted sum of avg NLL and log(MSE).

This loss is suitable for joint training on heterogeneous objectives
(categorical + continuous) by putting both losses on the same log scale.
The log(1 + MSE) formulation keeps gradients from the MSE term from
dominating the overall gradient.

By convention, for mixed targets:
- `ys` contains categorical targets at index 0 and continuous targets at
indices [1, 1 + n_continuous_targets).
- `y_hats` contains logits for categorical targets at indices
[0, n_categorical_targets) and predictions for continuous targets at indices
[n_categorical_targets, n_categorical_targets + n_continuous_targets).

Args:
ys: Ground truth targets.
y_hats: Network predictions (logits for categorical, values for continuous).
likelihood_weight: The weight for the average NLL. The log(MSE) is weighted
by (1 - likelihood_weight). Default is 0.5 for equal weighting.
n_categorical_targets: The number of output logits for the categorical
target.
n_continuous_targets: The number of continuous targets.

Returns:
The weighted sum of average NLL and log(1 + MSE).
"""
categorical_y_hats = y_hats[:, :, 0:n_categorical_targets]
categorical_ys = ys[:, :, 0:1]

# Negative categorical targets mark masked-out (invalid) timesteps.
mask = jnp.logical_not(categorical_ys < 0)
continuous_y_hats = y_hats[
:, :, n_categorical_targets : n_categorical_targets + n_continuous_targets
]
continuous_ys = ys[:, :, 1 : 1 + n_continuous_targets]
continuous_ys = jnp.where(mask, continuous_ys, jnp.nan)

nll, n_unmasked_samples = categorical_neg_log_likelihood(
categorical_ys, categorical_y_hats
)
avg_nll = nll / n_unmasked_samples

mse_val = mse(continuous_ys, continuous_y_hats)

# This is a trick to scale the gradients from the mse:
#   d/dm log(1 + mse) = 1 / (1 + mse) * d/dm mse
# Early in training, when mse is large, the gradient is damped; later, when
# mse is small, it is almost the same as the raw MSE gradient.
# Note that log(1 + mse) follows the same monotonic trend as mse, so this
# transform leads to an optimization similar to the original mse.
log_mse_val = jnp.log(1.0 + mse_val)

# Likelihood weight should ideally be set to 0.5, but it can also act as a
# toggle to train on just one objective at a time (1.0 trains on the
# likelihood only; 0.0 on the log-MSE only).
return avg_nll * likelihood_weight + log_mse_val * (1 - likelihood_weight)
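
A minimal usage sketch of the new loss, not part of this PR: it assumes the merged rnn_utils module is importable and uses invented shapes and values; the defaults (two logits, one continuous target) match the convention documented above.

import jax
import jax.numpy as jnp
import numpy as np
from disentangled_rnns.library import rnn_utils

# Invented toy batch: 5 timesteps, 2 episodes. Targets pack the categorical
# label at index 0 (-1 marks masked timesteps) and one continuous target at
# index 1.
ys = np.zeros((5, 2, 2))
ys[:, :, 0] = [[0, 1], [1, -1], [0, 0], [1, 1], [-1, 0]]
ys[:, :, 1] = np.random.uniform(size=(5, 2))
# Predictions: two logits for the binary target plus one continuous value.
y_hats = np.random.normal(size=(5, 2, 3))
loss = rnn_utils.avg_nll_and_log_mse(ys, y_hats)  # equal weighting by default

# The log(1 + mse) transform damps the MSE gradient when mse is large:
grad_fn = jax.grad(lambda m: jnp.log(1.0 + m))
print(grad_fn(10.0))  # ~0.09: heavily damped early in training
print(grad_fn(0.1))   # ~0.91: close to the raw gradient near convergence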


@jax.jit
def compute_penalty(
targets: np.ndarray, outputs: np.ndarray
@@ -594,7 +657,8 @@ def train_network(
{'penalty_scale': 0.1, 'likelihood_weight': 0.8}) or a single float for
simpler losses.
loss: The loss function to use. Options are 'mse', 'penalized_mse',
'categorical', 'penalized_categorical', 'hybrid', 'penalized_hybrid',
'penalized_log_hybrid'.
log_losses_every: How many training steps between each time we check for
errors and log the loss.
do_plot: Boolean that controls whether a learning curve is plotted.
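
For orientation, opting into the new loss would look roughly like the sketch below. Only the loss and loss_param arguments are documented in this hunk; the remaining train_network arguments are not shown by this diff, so the call itself is left commented out as a hypothetical.

from disentangled_rnns.library import rnn_utils

loss_param = {'penalty_scale': 0.1, 'likelihood_weight': 0.5}
# Hypothetical call; the other arguments follow the existing signature:
# rnn_utils.train_network(
#     ..., loss='penalized_log_hybrid', loss_param=loss_param)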
@@ -719,24 +783,35 @@ def hybrid_loss(
def penalized_hybrid_loss(
params, xs, ys, random_key, loss_param=loss_param_dict
) -> float:
"""A hybrid loss with a penalty."""
"""A hybrid loss with a penalty.

Useful for jointly training on categorical and continuous targets. Uses a
log(MSE) loss for the continuous targets so that it is on a similar scale
to the categorical loss.

Args:
params: The network parameters.
xs: The input data.
ys: The target data.
random_key: A JAX random key.
loss_param: Parameters for the loss function, potentially including
'penalty_scale' and 'likelihood_weight'.

Returns:
The computed penalized hybrid loss.
"""

penalty_scale = get_loss_param(loss_param, 'penalty_scale', 1.0)
model_output = model.apply(params, random_key, xs)

# model_output has the categorical and continuous predictions first, followed
# by the penalty. avg_nll_and_log_mse handles ignoring the penalty, so we
# don't need to do anything special here.
y_hats = model_output
likelihood_weight = get_loss_param(loss_param, 'likelihood_weight', 0.5)
supervised_loss = avg_nll_and_log_mse(
ys, y_hats, likelihood_weight=likelihood_weight
)
# The supervised loss is an average over unmasked samples, so average the
# penalty as well before applying penalty_scale.
penalty, n_unmasked_samples = compute_penalty(ys, y_hats)
avg_penalty = penalty / n_unmasked_samples
loss = supervised_loss + penalty_scale * avg_penalty
return loss
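
A small numeric illustration, with invented numbers, of why the penalty is now averaged: compute_penalty returns a sum over unmasked samples, so without the division penalty_scale would need retuning whenever the number of unmasked samples changes.

# Invented sums of the same 0.04 per-sample penalty over different counts.
penalty_a, n_a = 4.0, 100     # small batch
penalty_b, n_b = 40.0, 1000   # larger batch
print(penalty_a / n_a, penalty_b / n_b)  # both 0.04: comparable across batches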

losses = {