Skip to content

Commit dc7b10b

Browse files
Factorize Models to Prepare HF Upload (#34)
* Factorize models with a BaseModel * Improve AFNO typing * Add tests for the different models * Do not pass dataset metadata to model * Remove unnecessary arguments in super Co-authored-by: François Rozet <francois.rozet@outlook.com> --------- Co-authored-by: François Rozet <francois.rozet@outlook.com>
1 parent 7c98aa1 commit dc7b10b

File tree

14 files changed

+224
-132
lines changed

14 files changed

+224
-132
lines changed

tests/models/test_fno.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import pytest
2+
import torch
3+
from hydra.utils import instantiate
4+
from omegaconf import OmegaConf
5+
6+
from the_well.benchmark.models import (
7+
AFNO,
8+
FNO,
9+
TFNO,
10+
)
11+
12+
13+
@pytest.mark.parametrize("fno_model", [FNO, TFNO])
@pytest.mark.parametrize("dim_in", [1, 3, 5])
@pytest.mark.parametrize("dim_out", [1, 3, 5])
@pytest.mark.parametrize("n_spatial_dims", [2, 3])
@pytest.mark.parametrize("spatial_resolution", [16, 32])
def test_fno_model(fno_model, dim_in, dim_out, n_spatial_dims, spatial_resolution):
    """A forward pass through an FNO-family model maps a channels-first input
    of `dim_in` channels to an output with `dim_out` channels, preserving the
    batch and spatial dimensions."""
    resolution = [spatial_resolution] * n_spatial_dims
    n_modes = 2  # tiny spectral truncation keeps the test fast
    model = fno_model(
        dim_in,
        dim_out,
        n_spatial_dims,
        resolution,
        n_modes,  # modes1
        n_modes,  # modes2
        n_modes,  # modes3
        8,  # hidden_channels
    )
    batch_size = 4
    x = torch.rand((batch_size, dim_in, *resolution))
    y = model(x)
    assert y.shape == (batch_size, dim_out, *resolution)
38+
39+
40+
@pytest.mark.parametrize("dim_in", [1, 3, 5])
@pytest.mark.parametrize("dim_out", [1, 3, 5])
@pytest.mark.parametrize("n_spatial_dims", [2, 3])
@pytest.mark.parametrize("spatial_resolution", [16, 32])
def test_afno(dim_in, dim_out, n_spatial_dims, spatial_resolution):
    """AFNO consumes a channels-last input and emits a channels-last output
    with the requested number of output channels."""
    resolution = [spatial_resolution] * n_spatial_dims
    model = AFNO(
        dim_in,
        dim_out,
        n_spatial_dims,
        resolution,
        hidden_dim=8,
        n_blocks=2,
        cmlp_diagonal_blocks=1,
    )
    batch_size = 4
    x = torch.rand((batch_size, *resolution, dim_in))
    y = model(x)
    assert y.shape == (batch_size, *resolution, dim_out)
59+
60+
61+
def test_load_fno_conf():
    """The FNO Hydra config file instantiates to a working FNO model.

    NOTE(review): the config path is relative to the working directory —
    assumes pytest is run from the repository root.
    """
    config = OmegaConf.load("the_well/benchmark/configs/model/fno.yaml")
    n_spatial_dims = 2
    resolution = [32, 32]
    channels = 2  # dim_in == dim_out, so the model should preserve the shape
    model = instantiate(
        config,
        n_spatial_dims=n_spatial_dims,
        spatial_resolution=resolution,
        dim_in=channels,
        dim_out=channels,
    )
    assert isinstance(model, FNO)
    x = torch.rand(8, channels, *resolution)
    assert model(x).shape == x.shape

tests/models/test_transformer.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import pytest
2+
import torch
3+
4+
from the_well.benchmark.models import AViT
5+
6+
7+
@pytest.mark.parametrize("dim_in", [1, 3, 5])
@pytest.mark.parametrize("dim_out", [1, 3, 5])
@pytest.mark.parametrize("n_spatial_dims", [2, 3])
@pytest.mark.parametrize("spatial_resolution", [16, 32])
def test_avit(dim_in, dim_out, n_spatial_dims, spatial_resolution):
    """AViT forward pass keeps batch and spatial dims of a channels-last
    input while remapping the channel count from `dim_in` to `dim_out`."""
    batch_size = 4
    resolution = [spatial_resolution] * n_spatial_dims
    model = AViT(
        dim_in,
        dim_out,
        n_spatial_dims,
        resolution,
        hidden_dim=48,
        num_heads=2,
        processor_blocks=2,
    )
    x = torch.rand((batch_size, *resolution, dim_in))
    y = model(x)
    expected = (batch_size, *resolution, dim_out)
    assert y.shape == expected

tests/models/test_unet.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import pytest
2+
import torch
3+
4+
from the_well.benchmark.models import UNetClassic, UNetConvNext
5+
6+
7+
@pytest.mark.parametrize("dim_in", [1, 3, 5])
@pytest.mark.parametrize("dim_out", [1, 3, 5])
@pytest.mark.parametrize("n_spatial_dims", [2, 3])
@pytest.mark.parametrize("spatial_resolution", [16, 32])
def test_unet(dim_in, dim_out, n_spatial_dims, spatial_resolution):
    """UNetClassic maps a channels-first input to a channels-first output of
    `dim_out` channels at the same spatial resolution."""
    batch_size = 4
    resolution = [spatial_resolution] * n_spatial_dims
    model = UNetClassic(dim_in, dim_out, n_spatial_dims, resolution)
    x = torch.rand((batch_size, dim_in, *resolution))
    y = model(x)
    expected = (batch_size, dim_out, *resolution)
    assert y.shape == expected
22+
23+
24+
@pytest.mark.parametrize("dim_in", [1, 3, 5])
@pytest.mark.parametrize("dim_out", [1, 3, 5])
@pytest.mark.parametrize("n_spatial_dims", [2, 3])
@pytest.mark.parametrize("spatial_resolution", [16, 32])
def test_unet_convnext(dim_in, dim_out, n_spatial_dims, spatial_resolution):
    """UNetConvNext with a reduced configuration (2 stages, 16 features)
    preserves batch and spatial shape while remapping channels."""
    resolution = [spatial_resolution] * n_spatial_dims
    batch_size = 4
    model = UNetConvNext(
        dim_in,
        dim_out,
        n_spatial_dims,
        resolution,
        stages=2,
        init_features=16,
    )
    x = torch.rand((batch_size, dim_in, *resolution))
    y = model(x)
    expected = (batch_size, dim_out, *resolution)
    assert y.shape == expected

tests/test_models.py

Lines changed: 0 additions & 62 deletions
This file was deleted.

the_well/benchmark/models/afno/__init__.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from einops import rearrange
1414
from timm.models.layers import DropPath, trunc_normal_
1515

16+
from the_well.benchmark.models.common import BaseModel
17+
1618

1719
class RealImagGELU(nn.Module):
1820
def forward(self, x):
@@ -163,26 +165,25 @@ def forward(self, x):
163165
return x
164166

165167

166-
class AFNO(nn.Module):
168+
class AFNO(BaseModel):
167169
def __init__(
168170
self,
169-
dim_in,
170-
dim_out,
171-
dset_metadata,
172-
hidden_dim=768,
173-
n_blocks=12, # Depth in original code - changing for consistency
174-
cmlp_diagonal_blocks=8, # num_blocks in original
175-
patch_size=8,
176-
mlp_ratio=4.0,
177-
drop_rate=0.0,
178-
drop_path_rate=0.0,
179-
sparsity_threshold=0.01,
171+
dim_in: int,
172+
dim_out: int,
173+
n_spatial_dims: int,
174+
spatial_resolution: tuple[int, ...],
175+
hidden_dim: int = 768,
176+
n_blocks: int = 12, # Depth in original code - changing for consistency
177+
cmlp_diagonal_blocks: int = 8, # num_blocks in original
178+
patch_size: int = 8,
179+
mlp_ratio: float = 4.0,
180+
drop_rate: float = 0.0,
181+
drop_path_rate: float = 0.0,
182+
sparsity_threshold: float = 0.01,
180183
):
181-
super().__init__()
184+
super().__init__(n_spatial_dims, spatial_resolution)
182185
self.dim_in = dim_in
183186
self.dim_out = dim_out
184-
self.resolution = dset_metadata.spatial_resolution
185-
self.n_spatial_dims = dset_metadata.n_spatial_dims
186187
self.n_blocks = n_blocks
187188
self.cmlp_diagonal_blocks = cmlp_diagonal_blocks
188189
norm_layer = partial(nn.LayerNorm, eps=1e-6)
@@ -211,7 +212,7 @@ def __init__(
211212
self.patch_debed = nn.ConvTranspose3d(
212213
hidden_dim, dim_out, kernel_size=patch_size, stride=patch_size
213214
)
214-
self.inner_size = [k // patch_size for k in self.resolution]
215+
self.inner_size = [k // patch_size for k in self.spatial_resolution]
215216
pos_embed_size = [1] + self.inner_size + [hidden_dim]
216217
self.pos_embed = nn.Parameter(0.02 * torch.randn(pos_embed_size))
217218
self.pos_drop = nn.Dropout(p=drop_rate)

the_well/benchmark/models/avit/__init__.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from einops import rearrange
1010
from timm.layers import DropPath
1111

12-
from the_well.data.datasets import WellMetadata
12+
from the_well.benchmark.models.common import BaseModel
1313

1414

1515
class hMLP_stem(nn.Module):
@@ -175,7 +175,7 @@ def forward(self, x):
175175
return input + self.drop_path(self.gamma * x)
176176

177177

178-
class AViT(nn.Module):
178+
class AViT(BaseModel):
179179
"""
180180
Uses axial attention to predict forward dynamics. This simplified version
181181
just stacks time in channels.
@@ -191,36 +191,36 @@ def __init__(
191191
self,
192192
dim_in: int,
193193
dim_out: int,
194-
dset_metadata: WellMetadata,
194+
n_spatial_dims: int,
195+
spatial_resolution: tuple[int, ...],
195196
hidden_dim: int = 768,
196197
num_heads: int = 12,
197198
processor_blocks: int = 8,
198199
drop_path: float = 0.0,
199200
):
200-
super().__init__()
201+
super().__init__(n_spatial_dims, spatial_resolution)
201202
# Normalization - not used in the well
202203
self.drop_path = drop_path
203204
self.dp = np.linspace(0, drop_path, processor_blocks)
204205

205-
self.resolution = tuple(dset_metadata.spatial_resolution)
206206
# Patch size hardcoded at 16 in this implementation
207207
self.patch_size = 16
208208
# Embedding
209-
pe_size = tuple(int(k / self.patch_size) for k in self.resolution) + (
209+
pe_size = tuple(int(k / self.patch_size) for k in self.spatial_resolution) + (
210210
hidden_dim,
211211
)
212212
self.absolute_pe = nn.Parameter(0.02 * torch.randn(*pe_size))
213213
self.embed = hMLP_stem(
214214
dim_in=dim_in,
215215
hidden_dim=hidden_dim,
216-
n_spatial_dims=dset_metadata.n_spatial_dims,
216+
n_spatial_dims=self.n_spatial_dims,
217217
)
218218
self.blocks = nn.ModuleList(
219219
[
220220
AxialAttentionBlock(
221221
hidden_dim=hidden_dim,
222222
num_heads=num_heads,
223-
n_spatial_dims=dset_metadata.n_spatial_dims,
223+
n_spatial_dims=self.n_spatial_dims,
224224
drop_path=self.dp[i],
225225
)
226226
for i in range(processor_blocks)
@@ -229,12 +229,12 @@ def __init__(
229229
self.debed = hMLP_output(
230230
hidden_dim=hidden_dim,
231231
dim_out=dim_out,
232-
n_spatial_dims=dset_metadata.n_spatial_dims,
232+
n_spatial_dims=self.n_spatial_dims,
233233
)
234234

235-
if dset_metadata.n_spatial_dims == 2:
235+
if self.n_spatial_dims == 2:
236236
self.embed_reshapes = ["b h w c -> b c h w", "b c h w -> b h w c"]
237-
if dset_metadata.n_spatial_dims == 3:
237+
if self.n_spatial_dims == 3:
238238
self.embed_reshapes = ["b h w d c -> b c h w d", "b c h w d -> b h w d c"]
239239

240240
def forward(self, x):

the_well/benchmark/models/common.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,21 @@
11
import torch
22
import torch.nn as nn
33
import torch.nn.functional as F
4+
from huggingface_hub import PyTorchModelHubMixin
45
from torch.nn.utils.parametrizations import spectral_norm
56

67

8+
class BaseModel(nn.Module, PyTorchModelHubMixin):
    """Common base class for benchmark models.

    Records the spatial geometry every model needs and mixes in
    `PyTorchModelHubMixin`, which presumably provides Hugging Face Hub
    save/load support for subclasses.

    Args:
        n_spatial_dims: Number of spatial dimensions of the data (e.g. 2 or 3).
        spatial_resolution: Grid size along each spatial dimension.
    """

    def __init__(self, n_spatial_dims: int, spatial_resolution: tuple[int, ...]):
        super().__init__()
        # Exposed as attributes so subclasses can derive patch/mode sizes.
        self.n_spatial_dims = n_spatial_dims
        self.spatial_resolution = spatial_resolution
17+
18+
719
class NestedLinear(nn.Linear):
820
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
921
super().__init__(

0 commit comments

Comments
 (0)