Commit dfa0dd7

Merge pull request jax-ml#25162 from nireekshak:testbranch
PiperOrigin-RevId: 702375099
Parents: 8c66cba, f43fa9f

File tree: 7 files changed (+7, -7 lines)


docs/gpu_performance_tips.md

Lines changed: 1 addition & 1 deletion
@@ -112,7 +112,7 @@ don't seem useful for multi-host communication yet.
 
 ## Multi-Process
 
-We recommand using one process per GPU and not one per node. In some
+We recommend using one process per GPU and not one per node. In some
 cases, this can speed up jitted computation. The
 {func}`jax.distributed.initialize` API will automatically understand
 that configuration when run under SLURM. However, this only a rule of
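The passage fixed here advises one process per GPU rather than one per node. As an illustrative sketch (the node and GPU counts and the `train.py` script name are hypothetical), that advice maps to one SLURM task per GPU:

```shell
# Hypothetical cluster: 2 nodes with 8 GPUs each -> 16 tasks total.
# With one task per GPU, jax.distributed.initialize() can read the
# SLURM environment and needs no explicit arguments inside train.py.
srun --nodes=2 --ntasks-per-node=8 --gpus-per-task=1 python train.py
```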

docs/gradient-checkpointing.md

Lines changed: 1 addition & 1 deletion
@@ -443,7 +443,7 @@ print_fwd_bwd(f, 3.)
 
 When differentiated functions are staged out to XLA for compilation — for example by applying {func}`jax.jit` to a function which contains a {func}`jax.grad` call — XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, **{func}`jax.checkpoint` often isn't needed for differentiated functions under a {func}`jax.jit`**. XLA will optimize things for you.
 
-One exception is when using staged-out control flow, like {func}`jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives (for example, across a forward-pass `scan` and the corresponding backward-pass `scan`), typically aren't aren't as thorough. As a result, it's often a good idea to use {func}`jax.checkpoint` on the body function passed to {func}`jax.lax.scan`.
+One exception is when using staged-out control flow, like {func}`jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives (for example, across a forward-pass `scan` and the corresponding backward-pass `scan`), typically aren't as thorough. As a result, it's often a good idea to use {func}`jax.checkpoint` on the body function passed to {func}`jax.lax.scan`.
 
 For example, one common pattern in large [Transformer models](https://en.wikipedia.org/wiki/Transformer_(machine_learning_model)) is to express the architecture as a {func}`jax.lax.scan` over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this:
 
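The corrected sentence recommends applying `jax.checkpoint` to the body function passed to `jax.lax.scan`. A minimal sketch of that pattern (the `layer` function and the weight shapes are illustrative, not from the docs):

```python
import jax
import jax.numpy as jnp

def layer(x, w):
    # scan body: (carry, per-step input) -> (new carry, per-step output)
    return jnp.tanh(w @ x), None

def net(ws, x):
    # Wrap the body in jax.checkpoint so the backward pass rematerializes
    # each layer's activations instead of storing them all.
    x, _ = jax.lax.scan(jax.checkpoint(layer), x, ws)
    return jnp.sum(x)

ws = jnp.stack([jnp.eye(3)] * 4)  # 4 "layers" of 3x3 weights
x = jnp.ones(3)
g = jax.grad(net, argnums=1)(ws, x)  # gradient w.r.t. the input x
```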

docs/notebooks/Common_Gotchas_in_JAX.ipynb

Lines changed: 1 addition & 1 deletion
@@ -202,7 +202,7 @@
    "id": "cDpQ5u63Ba_H"
   },
   "source": [
-   "It is not recommended to use iterators in any JAX function you want to `jit` or in any control-flow primitive. The reason is that an iterator is a python object which introduces state to retrieve the next element. Therefore, it is incompatible with JAX functional programming model. In the code below, there are some examples of incorrect attempts to use iterators with JAX. Most of them return an error, but some give unexpected results."
+   "It is not recommended to use iterators in any JAX function you want to `jit` or in any control-flow primitive. The reason is that an iterator is a python object which introduces state to retrieve the next element. Therefore, it is incompatible with JAX's functional programming model. In the code below, there are some examples of incorrect attempts to use iterators with JAX. Most of them return an error, but some give unexpected results."
   ]
  },
  {

docs/notebooks/Common_Gotchas_in_JAX.md

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ print(jit(pure_uses_internal_state)(5.))
 
 +++ {"id": "cDpQ5u63Ba_H"}
 
-It is not recommended to use iterators in any JAX function you want to `jit` or in any control-flow primitive. The reason is that an iterator is a python object which introduces state to retrieve the next element. Therefore, it is incompatible with JAX functional programming model. In the code below, there are some examples of incorrect attempts to use iterators with JAX. Most of them return an error, but some give unexpected results.
+It is not recommended to use iterators in any JAX function you want to `jit` or in any control-flow primitive. The reason is that an iterator is a python object which introduces state to retrieve the next element. Therefore, it is incompatible with JAX's functional programming model. In the code below, there are some examples of incorrect attempts to use iterators with JAX. Most of them return an error, but some give unexpected results.
 
 ```{code-cell} ipython3
 :id: w99WXa6bBa_H
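The edited passage warns that iterators inside jitted functions or control-flow primitives can silently misbehave. A short sketch of the "unexpected result" case (adapted from the gotcha being described):

```python
import jax

# An iterator used inside a traced loop gives a wrong answer, not an error.
iterator = iter(range(10))
# The loop body is traced once; next(iterator) yields 0 during tracing, so
# the compiled loop adds the constant 0 on every iteration instead of
# consuming successive iterator values.
result = jax.lax.fori_loop(0, 10, lambda i, x: x + next(iterator), 0)
print(result)  # 0, not 45
```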

docs/notebooks/autodiff_remat.ipynb

Lines changed: 1 addition & 1 deletion
@@ -1129,7 +1129,7 @@
   "source": [
    "When differentiated functions are staged out to XLA for compilation, for example by applying `jax.jit` to a function which contains a `jax.grad` call, XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, **`jax.checkpoint` often isn't needed for differentiated functions under a `jax.jit`**. XLA will optimize things for you.\n",
    "\n",
-   "One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass `scan` and the corresponding backward-pass `scan`, typically aren't aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function passed to `jax.lax.scan`.\n",
+   "One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass `scan` and the corresponding backward-pass `scan`, typically aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function passed to `jax.lax.scan`.\n",
    "\n",
    "For example, one common pattern in large [Transformer models](https://en.wikipedia.org/wiki/Transformer_(machine_learning_model)) is to express the architecture as a `jax.lax.scan` over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this:"
   ]

docs/notebooks/autodiff_remat.md

Lines changed: 1 addition & 1 deletion
@@ -490,7 +490,7 @@ print_fwd_bwd(f, 3.)
 
 When differentiated functions are staged out to XLA for compilation, for example by applying `jax.jit` to a function which contains a `jax.grad` call, XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, **`jax.checkpoint` often isn't needed for differentiated functions under a `jax.jit`**. XLA will optimize things for you.
 
-One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass `scan` and the corresponding backward-pass `scan`, typically aren't aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function passed to `jax.lax.scan`.
+One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass `scan` and the corresponding backward-pass `scan`, typically aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function passed to `jax.lax.scan`.
 
 For example, one common pattern in large [Transformer models](https://en.wikipedia.org/wiki/Transformer_(machine_learning_model)) is to express the architecture as a `jax.lax.scan` over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this:
 

docs/xla_flags.md

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ XLA_FLAGS='--flag1=value1 --flag2=value2' python3 source.py
 | Flag | Type | Notes |
 | ---- | ---- | ----- |
 | `xla_tpu_enable_data_parallel_all_reduce_opt` | Boolean (true/false) | Optimization to increase overlap opportunities for DCN (data center networking) all-reduces used for data parallel sharding. |
-| `xla_tpu_data_parallel_opt_different_sized_ops` | Boolean (true/false) | Enables pipelining of data parallel ops across multiple iterations even if their output sizes doesn't match what can Be saved in place in the stacked variables. Can increase memory pressure. |
+| `xla_tpu_data_parallel_opt_different_sized_ops` | Boolean (true/false) | Enables pipelining of data parallel ops across multiple iterations even if their output sizes don't match what can be saved in place in the stacked variables. Can increase memory pressure. |
 | `xla_tpu_enable_async_collective_fusion` | Boolean (true/false) | Enables the pass which fuses async collective communications with compute ops (output/loop-fusion or convolution) that are scheduled between their -start and -done instructions. |
 | `xla_tpu_enable_async_collective_fusion_fuse_all_gather` | TristateFlag (true/false/auto) | Enables fusing all-gathers within the AsyncCollectiveFusion pass. <br>If set to `auto`, it will be enabled based on the target. |
 | `xla_tpu_enable_async_collective_fusion_multiple_steps` | Boolean (true/false) | Enables continuing the same async collective in multiple steps (fusions) in the AsyncCollectiveFusion pass. |
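As the hunk context shows, these flags are passed as one space-separated `XLA_FLAGS` string. An alternative sketch sets the same variable from Python before `jax` is imported (the particular flag combination here is hypothetical):

```python
import os

# XLA_FLAGS is read once when the XLA backend initializes, so it must be
# set before importing jax (or any other XLA-backed library).
os.environ["XLA_FLAGS"] = (
    "--xla_tpu_enable_data_parallel_all_reduce_opt=true "
    "--xla_tpu_data_parallel_opt_different_sized_ops=true"
)
print(os.environ["XLA_FLAGS"])
```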
