Skip to content

Commit c06c785

Browse files
Merge pull request jax-ml#26668 from jakevdp:sharp-bits
PiperOrigin-RevId: 730915911
2 parents 69a6aaa + f5ca46f commit c06c785

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

docs/notebooks/Common_Gotchas_in_JAX.ipynb

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1309,6 +1309,20 @@
13091309
"\n",
13101310
" ```\n",
13111311
" This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.\n",
1312+
"- When operating on [subnormal](https://en.wikipedia.org/wiki/Subnormal_number)\n",
1313+
" floating point numbers, JAX operations use flush-to-zero semantics on some\n",
1314+
" backends. For example:\n",
1315+
" ```python\n",
1316+
" >>> import jax.numpy as jnp\n",
1317+
" >>> subnormal = jnp.float32(1E-45)\n",
1318+
" >>> subnormal # subnormals are representable\n",
1319+
" Array(1.e-45, dtype=float32)\n",
1320+
" >>> subnormal + 0 # but are flushed to zero within operations\n",
1321+
" Array(0., dtype=float32)\n",
1322+
"\n",
1323+
" ```\n",
1324+
" The detailed operation semantics for subnormal values will generally\n",
1325+
" vary depending on the backend.\n",
13121326
"\n",
13131327
"## 🔪 Sharp bits covered in tutorials\n",
13141328
"- {ref}`control-flow` discusses how to work with the constraints that `jit` imposes on the use of Python control flow and logical operators.\n",

docs/notebooks/Common_Gotchas_in_JAX.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,20 @@ Many such cases are discussed in detail in the sections above; here we list seve
677677

678678
```
679679
This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.
680+
- When operating on [subnormal](https://en.wikipedia.org/wiki/Subnormal_number)
681+
floating point numbers, JAX operations use flush-to-zero semantics on some
682+
backends. For example:
683+
```python
684+
>>> import jax.numpy as jnp
685+
>>> subnormal = jnp.float32(1E-45)
686+
>>> subnormal # subnormals are representable
687+
Array(1.e-45, dtype=float32)
688+
>>> subnormal + 0 # but are flushed to zero within operations
689+
Array(0., dtype=float32)
690+
691+
```
692+
The detailed operation semantics for subnormal values will generally
693+
vary depending on the backend.
680694

681695
## 🔪 Sharp bits covered in tutorials
682696
- {ref}`control-flow` discusses how to work with the constraints that `jit` imposes on the use of Python control flow and logical operators.

0 commit comments

Comments
 (0)