The following provides a small example, training a vision transformer on Cifar100, that presents all the important features of `mpx`. For details, see `examples/train_vit.py`.
This example does not go into the details of the neural network itself, only the `mpx`-relevant parts.
### Installation and Execution of the Example
First, install JAX for your hardware.
Then, install all dependencies via
```bash
pip install -r examples/requirements.txt
```
Then you can run the example with the command below. ATTENTION: the script downloads Cifar100.
```bash
python -m examples.train_vit
```
### Explanation
The loss scaling has to be initialized when the datasets, models, etc. are instantiated. Typically, its initial value is set to the maximum finite value of `float16`.
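
A minimal sketch of that initial value and of why loss scaling helps, written in plain NumPy for illustration only (it is not `mpx`'s actual API; see `examples/train_vit.py` for that):

```python
import numpy as np

# A common initial loss scale: the largest finite float16 value.
init_scale = float(np.finfo(np.float16).max)
print(init_scale)  # 65504.0

# The loss is multiplied by the scale before differentiation, so small
# gradients survive the cast to float16; the gradients are then divided
# by the same factor before the optimizer update.
grad = 1e-8                               # underflows to 0.0 in float16 unscaled
scaled = np.float16(grad * init_scale)    # representable after scaling
unscaled = float(scaled) / init_scale     # recovers the true magnitude
```

If the scaled gradients overflow to infinity instead, dynamic loss scaling lowers the scale and skips the update, which is why starting from the largest finite `float16` value is safe.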