
Commit 89075c8

added some more explanations
1 parent 9e29cca commit 89075c8

3 files changed (+20 / -22 lines)


README.md

Lines changed: 13 additions & 1 deletion

````diff
@@ -86,7 +86,19 @@ The gradient transformations might return gradients that are infinite. In this c
 The following provides a small example, training a vision transformer on Cifar100 presenting all the important features of `mpx`. For details, please visit examples/train_vit.py.
 This example will not go into the details for the neural network part, but just the `mpx` relevant parts.
 
-When loading the datasets, instantiating the models etc., you must instantiate the loss scaling. Typically, the initial value is set to the maximum value of `float16`.
+### Installation and Execution of the Example
+First, install JAX for your hardware.
+Then install all dependencies via
+```bash
+pip install -r examples/requirements.txt
+```
+You can then run the example via the command below. Note: the script downloads Cifar100.
+```bash
+python -m examples.train_vit
+```
+
+### Explanation
+The loss scaling has to be initialized during the instantiation of the datasets, models, etc. Typically, the initial value is set to the maximum value of `float16`.
 
 ```python
````

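The "maximum value of `float16`" used to seed the loss scale is a fixed constant (in JAX it is `jnp.finfo(jnp.float16).max`). A minimal dependency-free sketch of that constant; the name `init_loss_scale` is illustrative and not part of the `mpx` API:

```python
# Largest finite float16 value: (2 - 2**-10) * 2**15 = 65504.
# Loss scaling is typically initialized to this constant so that small
# gradients survive the cast to half precision before being unscaled.
FLOAT16_MAX = (2.0 - 2.0**-10) * 2.0**15

# Illustrative name only; the actual initialization is shown in examples/train_vit.py.
init_loss_scale = FLOAT16_MAX
print(init_loss_scale)  # 65504.0
```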
examples/requirements.txt

Lines changed: 0 additions & 14 deletions

```diff
@@ -25,10 +25,6 @@ idna==3.10
 immutabledict==4.2.1
 importlib_resources==6.5.2
 iniconfig==2.1.0
-jax==0.6.0
-jax-cuda12-pjrt==0.6.0
-jax-cuda12-plugin==0.6.0
-jaxlib==0.6.0
 jaxtyping==0.3.2
 keras==3.9.2
 kiwisolver==1.4.8
@@ -41,16 +37,6 @@ mdurl==0.1.2
 ml_dtypes==0.5.1
 namex==0.0.9
 numpy==2.1.3
-nvidia-cublas-cu12==12.9.0.13
-nvidia-cuda-cupti-cu12==12.9.19
-nvidia-cuda-nvcc-cu12==12.9.41
-nvidia-cuda-runtime-cu12==12.9.37
-nvidia-cudnn-cu12==9.10.1.4
-nvidia-cufft-cu12==11.4.0.6
-nvidia-cusolver-cu12==11.7.4.40
-nvidia-cusparse-cu12==12.5.9.5
-nvidia-nccl-cu12==2.26.5
-nvidia-nvjitlink-cu12==12.9.41
 opt_einsum==3.4.0
 optax==0.2.4
 optree==0.15.0
```

examples/train_vit.py

Lines changed: 7 additions & 7 deletions

```diff
@@ -188,8 +188,9 @@ def init_tf_dataloader_image(data_source, batch_size, num_epochs, seed, resoluti
     data = data.as_numpy_iterator()
     return data
 
-train_dataset = init_tf_dataloader_image(train_data_source, config["batch_size"], config["num_epochs"], 0, 32)
-val_dataset = init_tf_dataloader_image(val_data_source, config["batch_size"], config["num_epochs"], 0, 32)
+# we make the resolution way too high for CIFAR100, but this is just for testing and to force the training to use a lot of memory.
+train_dataset = init_tf_dataloader_image(train_data_source, config["batch_size"], config["num_epochs"], 0, 224)
+val_dataset = init_tf_dataloader_image(val_data_source, config["batch_size"], config["num_epochs"], 0, 224)
 
 #########################################
 # Sharding
@@ -226,7 +227,7 @@ def init_tf_dataloader_image(data_source, batch_size, num_epochs, seed, resoluti
 # Load optimizer
 ########################################
 # optimizer strategy from https://arxiv.org/abs/2106.10270
-duration_linear_schedule = 1000
+duration_linear_schedule = 100
 linear_schedule = optax.linear_schedule(
     init_value=config["learning_rate"] * 0.01,
     end_value=config["learning_rate"],
@@ -322,14 +323,13 @@ def init_tf_dataloader_image(data_source, batch_size, num_epochs, seed, resoluti
 if __name__ == "__main__":
     config = {
         "train_mixed_precision": True,
-        "batch_size": 512,
+        "batch_size": 256,
         "num_epochs": 10,
-        "num_features": 128,
+        "num_features": 256,
         "num_heads": 4,
-        "num_features_residual": 256,
+        "num_features_residual": 800,
         "num_transformer_blocks": 12,
         "learning_rate": 0.001,
-        "batch_size": 128,
         "weight_regularization": 0.001,
     }
     main(config)
```
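The `duration_linear_schedule` value changed in this commit controls a linear warmup: `optax.linear_schedule` ramps the learning rate from 1% of the target up to the target over that many steps, then holds it constant. A minimal pure-Python sketch of the same interpolation, using values that mirror the example's config; `linear_warmup` is a hypothetical helper, not a function from the script:

```python
def linear_warmup(step, init_value, end_value, transition_steps):
    # Mirrors optax.linear_schedule: linear interpolation from init_value
    # to end_value, held constant once transition_steps is reached.
    frac = min(step, transition_steps) / transition_steps
    return init_value + frac * (end_value - init_value)

lr = 0.001  # config["learning_rate"] in the example
print(linear_warmup(0, lr * 0.01, lr, 100))    # 1% of the target rate at step 0
print(linear_warmup(50, lr * 0.01, lr, 100))   # halfway through the warmup ramp
print(linear_warmup(100, lr * 0.01, lr, 100))  # full learning rate after warmup
print(linear_warmup(500, lr * 0.01, lr, 100))  # held constant past the warmup
```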
