-
|
Hi JAX team, I want to confirm whether this is expected behavior and how to reason about it.

Summary
For identical inputs, `jax.lax.div` produces slightly different fp32 results on TPU than on GPU.

Environment
ReproI run the same seeded input on TPU and GPU (saved GPU output, then compared on TPU): import os
# Set before importing jax so the backend's C++ logging level is picked up.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import jax
import jax.numpy as jnp
import torch
import numpy as np
# NOTE(review): x64 is enabled globally, but the inputs below are explicitly
# cast to float32, so the comparison itself is an fp32 comparison.
jax.config.update("jax_enable_x64", True)
# Seeded host-side RNG so both backends receive byte-identical inputs.
rng = np.random.default_rng(seed=0)
x = rng.normal(size=(1024, 1024))
y = rng.normal(size=(1024, 1024))
x_j = jnp.array(x, dtype=jnp.float32)  # float64 host data down-cast to fp32
y_j = jnp.array(y, dtype=jnp.float32)
o = jax.lax.div(x_j, y_j)


def to_numpy_f32(arr):
    """Coerce *arr* (a torch.Tensor or any array-like) to a C-contiguous float32 ndarray."""
    if isinstance(arr, torch.Tensor):
        # Detach from autograd and move to host memory before handing to NumPy.
        host = arr.detach().cpu().numpy()
    else:
        host = np.asarray(arr)
    converted = host.astype(np.float32, copy=False)
    return np.ascontiguousarray(converted)
def ulp_diff_float32(a, b):
    """Return the elementwise ULP distance between two float32 arrays.

    Each float's raw bits are mapped onto a single monotonic integer scale
    (negative floats are reflected below the non-negative range), so the
    absolute difference of the mapped values counts representable float32
    values between the two inputs, including across the sign boundary
    (-0.0 vs +0.0 is 0 ULP).

    Args:
        a, b: float32 ndarrays of the same shape.

    Returns:
        int64 ndarray of ULP distances.
    """
    # BUGFIX: the bits must be read as *unsigned* 32-bit values. Viewing them
    # as signed int32 made every negative float's mapped value too large by
    # exactly 2**32, so any pair straddling zero reported ~4.3e9 ULP.
    # Widening to int64 gives headroom for the subtraction below.
    a_u = a.view(np.uint32).astype(np.int64)
    b_u = b.view(np.uint32).astype(np.int64)
    # Sign bit set (>= 0x80000000) means negative: reflect below zero so the
    # scale increases monotonically from -inf through -0.0/+0.0 to +inf.
    a_ordered = np.where(a_u >= 0x80000000, 0x80000000 - a_u, a_u)
    b_ordered = np.where(b_u >= 0x80000000, 0x80000000 - b_u, b_u)
    return np.abs(a_ordered - b_ordered)
def compare_metrics(name, test, ref):
test = to_numpy_f32(test)
ref = to_numpy_f32(ref)
finite_mask = np.isfinite(test) & np.isfinite(ref)
skipped = test.size - int(finite_mask.sum())
test_f = test[finite_mask]
ref_f = ref[finite_mask]
abs_err = np.abs(test_f - ref_f)
rel_denom = np.maximum(np.abs(ref_f), np.finfo(np.float32).eps)
rel_err = abs_err / rel_denom
ulp_err = ulp_diff_float32(test_f, ref_f)
print(f"\n{name}")
print(
f" valid elements: {test_f.size}/{test.size}, skipped non-finite pairs: {skipped}"
)
print(
f" abs_err: max={abs_err.max():.6e}, mean={abs_err.mean():.6e}, p99={np.percentile(abs_err, 99):.6e}"
)
print(
f" rel_err: max={rel_err.max():.6e}, mean={rel_err.mean():.6e}, p99={np.percentile(rel_err, 99):.6e}"
)
print(
f" ulp_err: max={int(ulp_err.max())}, mean={ulp_err.mean():.6f}, p99={np.percentile(ulp_err, 99):.6f}"
)Observed metrics (TPU vs GPU, fp32)
Max-diff example:
Extra notes
Questions
|
Beta Was this translation helpful? Give feedback.
Answered by
hawkinsp
Feb 26, 2026
Replies: 1 comment 1 reply
-
|
Yes, this is expected. TPU implements floating point division as multiplication by the reciprocal. The best we could really do here is document it. |
Beta Was this translation helpful? Give feedback.
1 reply
Answer selected by
lingebeng
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Yes, this is expected. TPU implements floating point division as multiplication by the reciprocal. The best we could really do here is document it.