Inconsistent nan behavior across devices and jit vs non-jit #4780

@alvarosg

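Long story short, an expression of the form loss + 0. * nan evaluates to nan in non-jitted and GPU-jitted code (as expected), but in CPU-jitted code it evaluates to loss, as if the 0. * nan term were dropped or resolved to 0. This only happens when the jitted function returns certain structures (in the example below, a dict that also contains the intermediate losses).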

Here's an example Colab notebook:

https://colab.research.google.com/drive/1DMxoij453nBeaGx-r1IuLC88CeksR0q_?usp=sharing

And here's the code.

import jax.numpy as jnp
import jax
import numpy as np

def loss_fn(values, mask):
  # introduce some nans manually according to a mask
  values_masked = jnp.where(mask, jnp.nan, values)
  losses = {}
  losses["loss_1"] = jnp.mean(values)
  losses["loss_2"] = jnp.mean(values_masked)
  losses["total_loss"] = losses["loss_1"] + losses["loss_2"] * 0.
  return losses

  # If we just return the total loss instead, we get nan in ALL cases.
  # return losses["total_loss"]

loss_fn_jitted_cpu = jax.jit(loss_fn, backend="cpu")
loss_fn_jitted_gpu = jax.jit(loss_fn, backend="gpu")

values = np.array([1., 1.])
mask = np.array([False, True])

# Note how on CPU the total loss is 1. and not nan!!!! somehow it
# trims the `losses["loss_2"] * 0.` term, or resolves it to 0.
print(loss_fn(values, mask))             # {'loss_1': DeviceArray(1., dtype=float32), 'loss_2': DeviceArray(nan, dtype=float32), 'total_loss': DeviceArray(nan, dtype=float32)}
print(loss_fn_jitted_gpu(values, mask))  # {'loss_1': DeviceArray(1., dtype=float32), 'loss_2': DeviceArray(nan, dtype=float32), 'total_loss': DeviceArray(nan, dtype=float32)}
print(loss_fn_jitted_cpu(values, mask))  # {'loss_1': DeviceArray(1., dtype=float32), 'loss_2': DeviceArray(nan, dtype=float32), 'total_loss': DeviceArray(1., dtype=float32)}

# If we just return the total loss instead, then we get nan in all 3 functions (as it should be)
# print(loss_fn(values, mask))             # nan
# print(loss_fn_jitted_gpu(values, mask))  # nan
# print(loss_fn_jitted_cpu(values, mask))  # nan
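
For reference, here is a minimal sketch in plain Python (no JAX) of why nan is the expected result: under IEEE 754 arithmetic, 0. * nan is nan, so adding it to any loss gives nan. The helpers `expected_total` and `differing_keys` are hypothetical names introduced only for this illustration, and the `cpu_jitted` dict is just the CPU-jitted output above copied in as plain floats, not something recomputed here.

```python
import math

def expected_total(loss_1, loss_2):
    # Same expression as losses["total_loss"] in loss_fn, on plain Python floats.
    return loss_1 + loss_2 * 0.0

def differing_keys(reference, candidate):
    """Return the keys whose values differ, treating nan as equal to nan."""
    diffs = []
    for key in reference:
        a, b = float(reference[key]), float(candidate[key])
        both_nan = math.isnan(a) and math.isnan(b)
        if not both_nan and a != b:
            diffs.append(key)
    return diffs

nan = float("nan")

# What the non-jitted function reports, rebuilt with plain floats.
non_jitted = {"loss_1": 1.0, "loss_2": nan, "total_loss": expected_total(1.0, nan)}

# Values copied by hand from the CPU-jitted output above (an assumption of this sketch).
cpu_jitted = {"loss_1": 1.0, "loss_2": nan, "total_loss": 1.0}

print(math.isnan(non_jitted["total_loss"]))    # True
print(differing_keys(non_jitted, cpu_jitted))  # ['total_loss']
```

The same kind of nan-aware comparison could be applied to the dicts returned by loss_fn, loss_fn_jitted_gpu and loss_fn_jitted_cpu, since a plain == check would report loss_2 as different everywhere (nan != nan).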

Labels: bug (Something isn't working)
