Closed
Labels: bug (Something isn't working)
Description
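Long story short: an expression of the form `loss + 0. * nan` evaluates to nan in non-jitted and GPU-jitted code (as expected). In CPU-jitted code it evaluates to `loss` instead, as if the `0. * nan` term were dropped or resolved to 0. This only happens when the jitted function returns certain structures, such as a dict that also contains the intermediate losses.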
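For reference, the expected result follows from IEEE 754 arithmetic, where multiplying NaN by zero still gives NaN. A minimal NumPy sketch of that expectation (the variable names here are illustrative and not taken from the reproduction code):

```python
import numpy as np

# Illustrative scalars standing in for the two loss terms.
loss = np.float32(1.0)
nan = np.float32(np.nan)

# Under IEEE 754, 0 * nan is nan, so the sum is nan as well.
total = loss + np.float32(0.0) * nan

print(total)            # nan
print(np.isnan(total))  # True
```

An optimization that rewrites `x * 0` to `0` is only valid when `x` is known to be finite, which would be consistent with the CPU-jitted behaviour reported here, although the issue does not confirm the cause.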
Here's an example colaboratory:
https://colab.research.google.com/drive/1DMxoij453nBeaGx-r1IuLC88CeksR0q_?usp=sharing
And here's the code.
import jax.numpy as jnp
import jax
import numpy as np


def loss_fn(values, mask):
    # introduce some nans manually according to a mask
    values_masked = jnp.where(mask, jnp.nan, values)
    losses = {}
    losses["loss_1"] = jnp.mean(values)
    losses["loss_2"] = jnp.mean(values_masked)
    losses["total_loss"] = losses["loss_1"] + losses["loss_2"] * 0.
    return losses
    # If we just return the total loss instead, we get nan in ALL cases.
    # return losses["total_loss"]


loss_fn_jitted_cpu = jax.jit(loss_fn, backend="cpu")
loss_fn_jitted_gpu = jax.jit(loss_fn, backend="gpu")

values = np.array([1., 1.])
mask = np.array([False, True])

# Note how on CPU the total loss is 1. and not nan! Somehow it
# trims the `losses["loss_2"] * 0.` term, or resolves it to 0.
print(loss_fn(values, mask))  # {'loss_1': DeviceArray(1., dtype=float32), 'loss_2': DeviceArray(nan, dtype=float32), 'total_loss': DeviceArray(nan, dtype=float32)}
print(loss_fn_jitted_gpu(values, mask))  # {'loss_1': DeviceArray(1., dtype=float32), 'loss_2': DeviceArray(nan, dtype=float32), 'total_loss': DeviceArray(nan, dtype=float32)}
print(loss_fn_jitted_cpu(values, mask))  # {'loss_1': DeviceArray(1., dtype=float32), 'loss_2': DeviceArray(nan, dtype=float32), 'total_loss': DeviceArray(1., dtype=float32)}

# If we just return the total loss instead, then we get nan in all 3 functions (as it should be)
# print(loss_fn(values, mask))  # nan
# print(loss_fn_jitted_gpu(values, mask))  # nan
# print(loss_fn_jitted_cpu(values, mask))  # nan
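To check whether a given JAX installation shows the discrepancy, the eager and CPU-jitted outputs can be compared key by key, treating NaN as equal to NaN. This is only a sketch: `find_mismatches` is a hypothetical helper, not part of JAX, and it pins both runs to the CPU with `jax.default_device` instead of using the `backend=` argument from the code above.

```python
import jax
import jax.numpy as jnp
import numpy as np


def loss_fn(values, mask):
    # Same computation as in the report: one loss is nan, then multiplied by 0.
    values_masked = jnp.where(mask, jnp.nan, values)
    loss_1 = jnp.mean(values)
    loss_2 = jnp.mean(values_masked)
    return {
        "loss_1": loss_1,
        "loss_2": loss_2,
        "total_loss": loss_1 + loss_2 * 0.0,
    }


def find_mismatches(fn, *args):
    """Hypothetical helper: keys whose eager and jitted CPU values differ."""
    cpu = jax.devices("cpu")[0]
    with jax.default_device(cpu):
        eager = fn(*args)            # op-by-op execution
        jitted = jax.jit(fn)(*args)  # whole function compiled by XLA
    return [
        key
        for key in eager
        # equal_nan=True so that nan vs nan does not count as a mismatch
        if not np.array_equal(
            np.asarray(eager[key]), np.asarray(jitted[key]), equal_nan=True
        )
    ]


values = np.array([1.0, 1.0])
mask = np.array([False, True])
print(find_mismatches(loss_fn, values, mask))
```

An empty list means the two execution modes agree. On an installation affected by the behaviour described in this issue, `total_loss` would be expected to appear in the list.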
VolodyaCO