Commit 290616e

committed
Merge branch 'main' of github.com:AlexGraefe/mixed_precision_for_JAX
2 parents 0f0aa88 + efd5c38

File tree

1 file changed: +1 -1 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -110,7 +110,7 @@ loss_scaling = mpx.DynamicLossScaling(loss_scaling=jnp.ones((1,), dtype=jnp.floa
 ```
 The loss_scaling object then must be passed to the training pipeline.
 
-The most important part is the training step. `mpx` makes transforming your training step into mixed precision very easy. As you can see, the only change you have to do is to replace a call to `eqx.filter_value_and_grad` with `mpx.filter_value_and_grad` and call the optimizer via `mpx.optimizer_update`. Also, do not forget to return `loss_scaling` in your step function as it is updated.
+The most important part is the training step. `mpx` makes transforming your training step into mixed precision very easy. As you can see, the only change you have to do is to replace a call to `eqx.filter_value_and_grad` with `mpx.filter_value_and_grad` and afterwards call the optimizer via `mpx.optimizer_update`. Also, do not forget to return `loss_scaling` in your step function, because `loss_scaling` is updated.
 
 ```python
 @eqx.filter_jit
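
The README text changed in this commit describes the mixed-precision training-step pattern: swap `eqx.filter_value_and_grad` for `mpx.filter_value_and_grad`, apply gradients via `mpx.optimizer_update`, and return the updated `loss_scaling`. The sketch below illustrates that shape only; the exact `mpx` signatures, the `loss_fn` helper, and the import name are assumptions not shown in this commit, so treat it as pseudocode rather than runnable usage.

```python
# Hypothetical sketch of the step described in the diff. Only the names
# mpx.filter_value_and_grad, mpx.optimizer_update and mpx.DynamicLossScaling
# appear in this commit; the argument lists below are assumptions.
import equinox as eqx
import mpx  # import name assumed

@eqx.filter_jit
def train_step(model, optimizer, opt_state, loss_scaling, batch):
    # Replaces eqx.filter_value_and_grad; wraps loss scaling around
    # the backward pass (assumed signature).
    loss, grads = mpx.filter_value_and_grad(loss_fn)(model, batch, loss_scaling)
    # Applies the gradients through mpx instead of optax directly
    # (assumed signature).
    model, opt_state = mpx.optimizer_update(model, optimizer, opt_state, grads)
    # loss_scaling is adjusted dynamically, so the step must return it.
    return model, opt_state, loss_scaling, loss
```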

0 commit comments
