
Commit fac03c7

Merge pull request #157 from microsoft/wesselb/fine-tuning-adjustments
Sample fine-tuning environment
2 parents 20567b2 + bd1f979 · commit fac03c7

File tree

7 files changed: +175 -70 lines changed

- .pre-commit-config.yaml
- aurora/model/aurora.py
- aurora/model/encoder.py
- aurora/model/swin3d.py
- docs/finetuning.md
- finetuning/Dockerfile
- finetuning/finetune.py

.pre-commit-config.yaml

Lines changed: 0 additions & 3 deletions

@@ -2,9 +2,6 @@ ci:
   autoupdate_commit_msg: "chore: Update pre-commit hooks"
   autofix_commit_msg: "style: Pre-commit fixes"
 
-default_language_version:
-  python: python3.10
-
 repos:
   - repo: meta
     hooks:

aurora/model/aurora.py

Lines changed: 59 additions & 54 deletions

@@ -145,10 +145,8 @@ def __init__(
             surf_stats (dict[str, tuple[float, float]], optional): For these surface-level
                 variables, adjust the normalisation to the given tuple consisting of a new location
                 and scale.
-            bf16_mode (bool, optional): To reduce memory usage, convert the tokens to BF16, run
-                the backbone in pure BF16, and run the decoder in FP16 AMP. This should enable a
-                gradient computation. USE AT YOUR OWN RISK. THIS WAS NOT USED DURING THE DEVELOPMENT
-                OF AURORA AND IS PURELY PROVIDED AS A STARTING POINT FOR FINE-TUNING.
+            autocast (bool, optional): To reduce memory usage, `torch.autocast` only the backbone
+                to BF16. This is critical to enable fine-tuning.
             level_condition (tuple[int | float, ...], optional): Make the patch embeddings dependent
                 on pressure level. If you want to enable this feature, provide a tuple of all
                 possible pressure levels.
@@ -228,6 +226,7 @@ def __init__(
             embed_dim=embed_dim,
             mlp_ratio=mlp_ratio,
             drop_path_rate=drop_path,
+            attn_drop_rate=drop_rate,
             drop_rate=drop_rate,
             use_lora=use_lora,
             lora_steps=lora_steps,
@@ -252,18 +251,16 @@ def __init__(
             modulation_heads=modulation_heads,
         )
 
-        if autocast and not bf16_mode:
+        if bf16_mode and not autocast:
             warnings.warn(
-                "The argument `autocast` no longer does anything due to limited utility. "
-                "Consider instead using `bf16_mode`.",
+                "`bf16_mode` was removed, because it caused serious issues for gradient "
+                "computation. `bf16_mode` now automatically activates `autocast`, which will not "
+                "save as much memory, but should be much more stable.",
                 stacklevel=2,
             )
+            autocast = True
 
-        self.bf16_mode = bf16_mode
-
-        if self.bf16_mode:
-            # We run the backbone in pure BF16.
-            self.backbone.to(torch.bfloat16)
+        self.autocast = autocast
 
     def forward(self, batch: Batch) -> Batch:
         """Forward pass.
@@ -327,44 +324,30 @@ def forward(self, batch: Batch) -> Batch:
             lead_time=self.timestep,
         )
 
-        # In BF16 mode, the backbone is run in pure BF16.
-        if self.bf16_mode:
-            x = x.to(torch.bfloat16)
-        x = self.backbone(
-            x,
-            lead_time=self.timestep,
-            patch_res=patch_res,
-            rollout_step=batch.metadata.rollout_step,
-        )
-
-        # In BF16 mode, the decoder is run in AMP PF16, and the output is converted back to FP32.
-        # We run in PF16 as opposed to BF16 for improved relative precision.
-        if self.bf16_mode:
-            device_type = (
-                "cuda"
-                if torch.cuda.is_available()
-                else "xpu"
-                if torch.xpu.is_available()
-                else "cpu"
-            )
-            context = torch.autocast(device_type=device_type, dtype=torch.float16)
-            x = x.to(torch.float16)
+        if self.autocast:
+            if torch.cuda.is_available():
+                device_type = "cuda"
+            elif torch.xpu.is_available():
+                device_type = "xpu"
+            else:
+                device_type = "cpu"
+            context = torch.autocast(device_type=device_type, dtype=torch.bfloat16)
         else:
            context = contextlib.nullcontext()
        with context:
-            pred = self.decoder(
+            x = self.backbone(
                 x,
-                batch,
                 lead_time=self.timestep,
                 patch_res=patch_res,
+                rollout_step=batch.metadata.rollout_step,
             )
-        if self.bf16_mode:
-            pred = dataclasses.replace(
-                pred,
-                surf_vars={k: v.float() for k, v in pred.surf_vars.items()},
-                static_vars={k: v.float() for k, v in pred.static_vars.items()},
-                atmos_vars={k: v.float() for k, v in pred.atmos_vars.items()},
-            )
+
+        pred = self.decoder(
+            x,
+            batch,
+            lead_time=self.timestep,
+            patch_res=patch_res,
+        )
 
         # Remove batch and history dimension from static variables.
         pred = dataclasses.replace(
@@ -520,27 +503,49 @@ def adapt_checkpoint_max_history_size(self, checkpoint: dict[str, torch.Tensor])
 
             checkpoint[name] = new_weight
 
-    def configure_activation_checkpointing(self):
+    def configure_activation_checkpointing(
+        self,
+        module_names: tuple[str, ...] = (
+            "Basic3DDecoderLayer",
+            "Basic3DEncoderLayer",
+            "LinearPatchReconstruction",
+            "Perceiver3DDecoder",
+            "Perceiver3DEncoder",
+            "Swin3DTransformerBackbone",
+            "Swin3DTransformerBlock",
+        ),
+    ) -> None:
         """Configure activation checkpointing.
 
         This is required in order to compute gradients without running out of memory.
+
+        Args:
+            module_names (tuple[str, ...], optional): Names of the modules to checkpoint
+                on.
+
+        Raises:
+            RuntimeError: If any module specified in `module_names` was not found and
+                thus could not be checkpointed.
         """
-        # Checkpoint these modules:
-        module_names = (
-            "Perceiver3DEncoder",
-            "Swin3DTransformerBackbone",
-            "Basic3DEncoderLayer",
-            "Basic3DDecoderLayer",
-            "Perceiver3DDecoder",
-            "LinearPatchReconstruction",
-        )
+
+        found: set[str] = set()
 
         def check(x: torch.nn.Module) -> bool:
             name = x.__class__.__name__
-            return name in module_names
+            if name in module_names:
+                found.add(name)
+                return True
+            else:
+                return False
 
         apply_activation_checkpointing(self, check_fn=check)
 
+        if found != set(module_names):
+            raise RuntimeError(
+                f'Could not checkpoint on the following modules: '
+                f'{", ".join(sorted(set(module_names) - found))}.'
+            )
+
 
 class AuroraPretrained(Aurora):
     """Pretrained version of Aurora."""

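Taken together, this diff replaces the experimental `bf16_mode` with an `autocast` flag that runs only the backbone under BF16 autocast, and it makes activation checkpointing configurable and strict about unknown module names. Below is a minimal sketch of how the new interface might be used; the subset of module names chosen here is purely illustrative and is not code from this commit:

```python
from aurora import AuroraPretrained

# Sketch only: `autocast=True` runs the backbone under torch.autocast with bfloat16,
# while the encoder and decoder stay in full precision.
model = AuroraPretrained(autocast=True)
model.load_checkpoint()

# Checkpoint a subset of the default modules. With the new signature, any name in
# `module_names` that cannot be found in the model raises a RuntimeError.
model.configure_activation_checkpointing(
    module_names=("Swin3DTransformerBackbone", "Swin3DTransformerBlock"),
)

model.train()
model = model.to("cuda")
```

Calling `configure_activation_checkpointing()` with no arguments checkpoints the full default tuple, which now also includes `Swin3DTransformerBlock`.
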
aurora/model/encoder.py

Lines changed: 2 additions & 2 deletions

@@ -41,7 +41,7 @@ def __init__(
         embed_dim: int = 1024,
         num_heads: int = 16,
         head_dim: int = 64,
-        drop_rate: float = 0.1,
+        drop_rate: float = 0.0,
         depth: int = 2,
         mlp_ratio: float = 4.0,
         max_history_size: int = 2,
@@ -66,7 +66,7 @@ def __init__(
                 Defaults to `16`.
             head_dim (int, optional): Dimension of attention heads used in aggregation blocks.
                 Defaults to `64`.
-            drop_rate (float, optional): Drop out rate for input patches. Defaults to `0.1`.
+            drop_rate (float, optional): Drop out rate for input patches. Defaults to `0.0`.
             depth (int, optional): Number of Perceiver cross-attention and feed-forward blocks.
                 Defaults to `2`.
             mlp_ratio (float, optional): Ratio of hidden dimensionality to embedding dimensionality

aurora/model/swin3d.py

Lines changed: 4 additions & 4 deletions

@@ -762,8 +762,8 @@ def __init__(
         mlp_ratio: float = 4.0,
         qkv_bias: bool = True,
         drop_rate: float = 0.0,
-        attn_drop_rate: float = 0.1,
-        drop_path_rate: float = 0.1,
+        attn_drop_rate: float = 0.0,
+        drop_path_rate: float = 0.0,
         lora_steps: int = 40,
         lora_mode: LoRAMode = "single",
         use_lora: bool = False,
@@ -785,8 +785,8 @@ def __init__(
             qkv_bias (bool): If `True`, add a learnable bias to the query, key, and value. Defaults
                 to `True`.
             drop_rate (float): Drop-out rate. Defaults to `0.0`.
-            attn_drop_rate (float): Attention drop-out rate. Defaults to `0.1`.
-            drop_path_rate (float): Stochastic depth rate. Defaults to `0.1`.
+            attn_drop_rate (float): Attention drop-out rate. Defaults to `0.0`.
+            drop_path_rate (float): Stochastic depth rate. Defaults to `0.0`.
             lora_steps (int, optional): Maximum number of LoRA roll-out steps. Defaults to `40`.
             lora_mode (str, optional): LoRA mode. `"single"` uses the same LoRA for all roll-out
                 steps, `"from_second"` uses the same LoRA from the second roll-out step on,

docs/finetuning.md

Lines changed: 45 additions & 7 deletions

@@ -10,6 +10,47 @@ model = AuroraPretrained()
 model.load_checkpoint()
 ```
 
+## Basic Fine-Tuning Environment
+
+We provide a very basic Docker image and fine-tuning loop to get you started.
+This Docker image is built from an NVIDIA PyTorch base image,
+so it is tailored to work with NVIDIA GPUs, and has been tested on an 80 GB A100.
+The image can be found at `finetuning/Dockerfile` and the fine-tuning
+loop at `finetuning/finetune.py`.
+Assuming that you have cloned the Aurora repository, you can build and run
+the image by running the following from the root of the repository:
+
+```bash
+docker build . -t aurora:latest -f finetuning/Dockerfile
+docker run --rm -it -v .:/app/aurora \
+    --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
+    aurora:latest
+```
+
+Then, within the image, execute
+
+```bash
+python finetuning/finetune.py
+```
+
+to run the sample fine-tuning loop.
+
+For example, on Azure, launch a VM with size `Standard_NC24ads_A100_v4`, image
+Ubuntu 24.04 LTS (x64), and 256 GB of disk space.
+Then [install CUDA](https://learn.microsoft.com/en-us/azure/virtual-machines/linux/n-series-driver-setup).
+Be sure to install the latest supported version of the CUDA Toolkit by
+checking `nvidia-smi` after installing the drivers with
+`sudo ubuntu-drivers autoinstall` and rebooting.
+Best performance is achieved with CUDA Toolkit 13.0 or higher, which
+requires drivers that support CUDA 13.0 or higher.
+Then install Docker with `sudo apt install docker.io`,
+set the right permissions for the current user with
+`sudo usermod -a -G docker $USER`,
+[install the NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html),
+and reboot.
+You should now be able to clone the repo and build and run the image using
+the instructions above.
+
 ## Computing Gradients
 
 To compute gradients, you will need an A100 with 80 GB of memory.
@@ -19,13 +60,7 @@ You can do this as follows:
 ```python
 from aurora import AuroraPretrained
 
-model = AuroraPretrained(
-    # BF16 mode is an EXPERIMENTAL mode that saves memory by running the backbone in pure BF16
-    # and the decoder in FP16 AMP. This should enable gradient computation. USE AT YOUR OWN RISK.
-    # THIS WAS NOT USED IN THE DEVELOPMENT OF AURORA AND IS PURELY PROVIDED AS A STARTING POINT
-    # FOR FINE-TUNING.
-    bf16_mode=True,
-)
+model = AuroraPretrained(autocast=True)
 model.load_checkpoint()
 
 batch = ... # Load some data.
@@ -39,6 +74,9 @@ loss = ...
 loss.backward()
 ```
 
+Here `autocast` enables AMP with `bfloat16` for only the backbone.
+This is necessary to be able to fit gradients in memory.
+
 ## Exploding Gradients
 
 When fine-tuning, you may run into very large gradient values.
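
The hunk above ends in the existing "Exploding Gradients" section, which warns that fine-tuning can produce very large gradient values. Purely as an illustration (neither this commit nor the docs prescribe it), one common mitigation is to clip the global gradient norm before the optimiser step:

```python
import torch


def clipped_step(
    loss_value: torch.Tensor, model: torch.nn.Module, opt: torch.optim.Optimizer
) -> None:
    """Backward pass with global gradient-norm clipping.

    The `max_norm` of 1.0 is an arbitrary example value, not a recommendation from Aurora.
    """
    loss_value.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    opt.step()
```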

finetuning/Dockerfile

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+FROM nvcr.io/nvidia/pytorch:25.08-py3
+COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
+
+WORKDIR /app
+SHELL ["/bin/bash", "-c"]
+
+# Create the environment and install the repo in editable mode.
+RUN mkdir -p /app/aurora/aurora
+COPY pyproject.toml LICENSE.txt /app/aurora/
+RUN touch /app/aurora/__init__.py \
+    && touch /app/aurora/README.md \
+    && uv venv --python 3.13 \
+    && SETUPTOOLS_SCM_PRETEND_VERSION=0.0.0 uv pip install -e /app/aurora
+
+# Use the environment automatically.
+ENV VIRTUAL_ENV="/app/.venv/"
+ENV PATH="/app/.venv/bin:$PATH"
+
+# Let the user enter at `/app/aurora`.
+WORKDIR /app/aurora

finetuning/finetune.py

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""
+
+from datetime import datetime
+
+import torch
+
+from aurora import AuroraPretrained, Batch, Metadata
+
+
+def loss(pred: Batch) -> torch.Tensor:
+    """A sample loss function. You should replace this with your own loss function."""
+    surf_values = pred.surf_vars.values()
+    atmos_values = pred.atmos_vars.values()
+    return sum((x * x).sum() for x in tuple(surf_values) + tuple(atmos_values))
+
+
+model = AuroraPretrained(autocast=True)
+model.load_checkpoint()
+model.configure_activation_checkpointing()
+model.train()
+model = model.to("cuda")
+
+opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
+
+for i in range(10):
+    print(f"Step {i}")
+
+    # Train on random data. You should replace this with your own data.
+    batch = Batch(
+        surf_vars={k: torch.randn(1, 2, 721, 1440) for k in ("2t", "10u", "10v", "msl")},
+        static_vars={k: torch.randn(721, 1440) for k in ("lsm", "z", "slt")},
+        atmos_vars={k: torch.randn(1, 2, 13, 721, 1440) for k in ("z", "u", "v", "t", "q")},
+        metadata=Metadata(
+            lat=torch.linspace(90, -90, 721),
+            lon=torch.linspace(0, 360, 1440 + 1)[:-1],
+            time=(datetime(2020, 6, 1, 12, 0),),
+            atmos_levels=(50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000),
+        ),
+    )
+
+    opt.zero_grad()
+    prediction = model(batch.to("cuda"))
+    loss_value = loss(prediction)
+    loss_value.backward()
+    opt.step()
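
The sample loop trains on random data and discards the result when the script exits. If you want to keep the fine-tuned weights, a plain PyTorch save at the end of `finetune.py` would do; this is an illustrative addition, not part of the commit, and the file name is an arbitrary choice:

```python
import torch

# Illustrative only: persist the fine-tuned parameters so they can be reloaded later,
# e.g. via model.load_state_dict(torch.load("aurora-finetuned.ckpt")).
torch.save(model.state_dict(), "aurora-finetuned.ckpt")
```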
