Commit efd5c38

Update README.md
1 parent 89075c8 commit efd5c38

File tree

1 file changed (+1, −1)

README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -108,7 +108,7 @@ loss_scaling = mpx.DynamicLossScaling(loss_scaling=jnp.ones((1,), dtype=jnp.floa
 ```
 The loss_scaling object then must be passed to the training pipeline.
 
-The most important part is the training step. `mpx` makes transforming your training step into mixed precision very easy. As you can see, the only change you have to do is to replace a call to `eqx.filter_value_and_grad` with `mpx.filter_value_and_grad` and call the optimizer via `mpx.optimizer_update`. Also, do not forget to return `loss_scaling` in your step function as it is updated.
+The most important part is the training step. `mpx` makes transforming your training step into mixed precision very easy. As you can see, the only change you have to do is to replace a call to `eqx.filter_value_and_grad` with `mpx.filter_value_and_grad` and afterwards call the optimizer via `mpx.optimizer_update`. Also, do not forget to return `loss_scaling` in your step function, because `loss_scaling` is updated.
 
 ```python
 @eqx.filter_jit
````
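The reason `loss_scaling` must be returned from the step function is that dynamic loss scaling is stateful: the scale is adjusted every step based on whether the gradients overflowed. As a minimal sketch of that standard recipe in plain NumPy (the function name and parameters here are illustrative, not part of `mpx`'s API):

```python
import numpy as np

def update_loss_scale(scale, grads, good_steps=0,
                      growth_factor=2.0, backoff_factor=0.5,
                      growth_interval=2000):
    """One dynamic loss-scaling update.

    Returns (new_scale, new_good_steps, apply_update):
    - on any non-finite gradient, halve the scale and skip the optimizer step;
    - after `growth_interval` consecutive finite steps, double the scale.
    """
    finite = all(np.all(np.isfinite(g)) for g in grads)
    if not finite:
        # Overflow: back off the scale and discard this step's gradients.
        return scale * backoff_factor, 0, False
    good_steps += 1
    if good_steps >= growth_interval:
        # Long stable run: try a larger scale to use more of fp16's range.
        return scale * growth_factor, 0, True
    return scale, good_steps, True
```

In this scheme the loss is multiplied by `scale` before differentiation and the gradients are divided by `scale` before the optimizer step, which is the bookkeeping `mpx.filter_value_and_grad` and `mpx.optimizer_update` handle for you.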
