Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions tests/models/test_fno.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import pytest
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

from the_well.benchmark.models import (
AFNO,
FNO,
TFNO,
)


@pytest.mark.parametrize("fno_model", [FNO, TFNO])
@pytest.mark.parametrize("dim_in", [1, 3, 5])
@pytest.mark.parametrize("dim_out", [1, 3, 5])
@pytest.mark.parametrize("n_spatial_dims", [2, 3])
@pytest.mark.parametrize("spatial_resolution", [16, 32])
def test_fno_model(fno_model, dim_in, dim_out, n_spatial_dims, spatial_resolution):
    """A forward pass through FNO/TFNO maps (B, C_in, *S) to (B, C_out, *S)."""
    resolution = [spatial_resolution] * n_spatial_dims
    n_modes = 2  # same mode count along each spectral axis
    n_hidden = 8
    model = fno_model(
        dim_in,
        dim_out,
        n_spatial_dims,
        resolution,
        n_modes,  # modes1
        n_modes,  # modes2
        n_modes,  # modes3
        n_hidden,
    )
    n_batch = 4
    x = torch.rand((n_batch, dim_in, *resolution))
    y = model(x)
    assert y.shape == (n_batch, dim_out, *resolution)


@pytest.mark.parametrize("dim_in", [1, 3, 5])
@pytest.mark.parametrize("dim_out", [1, 3, 5])
@pytest.mark.parametrize("n_spatial_dims", [2, 3])
@pytest.mark.parametrize("spatial_resolution", [16, 32])
def test_afno(dim_in, dim_out, n_spatial_dims, spatial_resolution):
    """AFNO keeps batch and spatial dims, mapping channels-last dim_in -> dim_out."""
    resolution = [spatial_resolution] * n_spatial_dims
    n_batch = 4
    model = AFNO(
        dim_in,
        dim_out,
        n_spatial_dims,
        resolution,
        hidden_dim=8,
        n_blocks=2,
        cmlp_diagonal_blocks=1,
    )
    x = torch.rand((n_batch, *resolution, dim_in))
    y = model(x)
    assert y.shape == (n_batch, *resolution, dim_out)


def test_load_fno_conf():
    """Instantiate an FNO from its shipped Hydra/OmegaConf YAML and run a forward pass.

    Guards against the YAML config drifting out of sync with the FNO
    constructor signature.
    """
    FNO_CONFIG_FILE = "the_well/benchmark/configs/model/fno.yaml"
    config = OmegaConf.load(FNO_CONFIG_FILE)
    n_spatial_dims = 2
    spatial_resolution = [32, 32]
    dim_in = 2
    dim_out = dim_in
    model = instantiate(
        config,
        n_spatial_dims=n_spatial_dims,
        spatial_resolution=spatial_resolution,
        dim_in=dim_in,
        dim_out=dim_out,
    )
    assert isinstance(model, FNO)
    batch_size = 8
    input = torch.rand(batch_size, dim_in, *spatial_resolution)
    output = model(input)
    # Assert the expected shape explicitly rather than comparing against
    # input.shape: the old check only held because dim_out == dim_in and
    # would silently pass a wrong output channel count.
    assert output.shape == (batch_size, dim_out, *spatial_resolution)
29 changes: 29 additions & 0 deletions tests/models/test_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest
import torch

from the_well.benchmark.models import AViT


@pytest.mark.parametrize("dim_in", [1, 3, 5])
@pytest.mark.parametrize("dim_out", [1, 3, 5])
@pytest.mark.parametrize("n_spatial_dims", [2, 3])
@pytest.mark.parametrize("spatial_resolution", [16, 32])
def test_avit(dim_in, dim_out, n_spatial_dims, spatial_resolution):
    """AViT maps a channels-last batch (B, *S, C_in) to (B, *S, C_out)."""
    resolution = [spatial_resolution] * n_spatial_dims
    n_batch = 4
    model = AViT(
        dim_in,
        dim_out,
        n_spatial_dims,
        resolution,
        hidden_dim=48,
        num_heads=2,
        processor_blocks=2,
    )
    x = torch.rand((n_batch, *resolution, dim_in))
    y = model(x)
    assert y.shape == (n_batch, *resolution, dim_out)
40 changes: 40 additions & 0 deletions tests/models/test_unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest
import torch

from the_well.benchmark.models import UNetClassic, UNetConvNext


@pytest.mark.parametrize("dim_in", [1, 3, 5])
@pytest.mark.parametrize("dim_out", [1, 3, 5])
@pytest.mark.parametrize("n_spatial_dims", [2, 3])
@pytest.mark.parametrize("spatial_resolution", [16, 32])
def test_unet(dim_in, dim_out, n_spatial_dims, spatial_resolution):
    """UNetClassic maps a channels-first batch (B, C_in, *S) to (B, C_out, *S)."""
    resolution = [spatial_resolution] * n_spatial_dims
    n_batch = 4
    model = UNetClassic(dim_in, dim_out, n_spatial_dims, resolution)
    x = torch.rand((n_batch, dim_in, *resolution))
    y = model(x)
    assert y.shape == (n_batch, dim_out, *resolution)


@pytest.mark.parametrize("dim_in", [1, 3, 5])
@pytest.mark.parametrize("dim_out", [1, 3, 5])
@pytest.mark.parametrize("n_spatial_dims", [2, 3])
@pytest.mark.parametrize("spatial_resolution", [16, 32])
def test_unet_convnext(dim_in, dim_out, n_spatial_dims, spatial_resolution):
    """UNetConvNext maps a channels-first batch (B, C_in, *S) to (B, C_out, *S)."""
    resolution = [spatial_resolution] * n_spatial_dims
    n_batch = 4
    model = UNetConvNext(
        dim_in, dim_out, n_spatial_dims, resolution, stages=2, init_features=16
    )
    x = torch.rand((n_batch, dim_in, *resolution))
    y = model(x)
    assert y.shape == (n_batch, dim_out, *resolution)
62 changes: 0 additions & 62 deletions tests/test_models.py

This file was deleted.

33 changes: 17 additions & 16 deletions the_well/benchmark/models/afno/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from einops import rearrange
from timm.models.layers import DropPath, trunc_normal_

from the_well.benchmark.models.common import BaseModel


class RealImagGELU(nn.Module):
def forward(self, x):
Expand Down Expand Up @@ -163,26 +165,25 @@ def forward(self, x):
return x


class AFNO(nn.Module):
class AFNO(BaseModel):
def __init__(
self,
dim_in,
dim_out,
dset_metadata,
hidden_dim=768,
n_blocks=12, # Depth in original code - changing for consistency
cmlp_diagonal_blocks=8, # num_blocks in original
patch_size=8,
mlp_ratio=4.0,
drop_rate=0.0,
drop_path_rate=0.0,
sparsity_threshold=0.01,
dim_in: int,
dim_out: int,
n_spatial_dims: int,
spatial_resolution: tuple[int, ...],
hidden_dim: int = 768,
n_blocks: int = 12, # Depth in original code - changing for consistency
cmlp_diagonal_blocks: int = 8, # num_blocks in original
patch_size: int = 8,
mlp_ratio: float = 4.0,
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
sparsity_threshold: float = 0.01,
):
super().__init__()
super().__init__(n_spatial_dims, spatial_resolution)
self.dim_in = dim_in
self.dim_out = dim_out
self.resolution = dset_metadata.spatial_resolution
self.n_spatial_dims = dset_metadata.n_spatial_dims
self.n_blocks = n_blocks
self.cmlp_diagonal_blocks = cmlp_diagonal_blocks
norm_layer = partial(nn.LayerNorm, eps=1e-6)
Expand Down Expand Up @@ -211,7 +212,7 @@ def __init__(
self.patch_debed = nn.ConvTranspose3d(
hidden_dim, dim_out, kernel_size=patch_size, stride=patch_size
)
self.inner_size = [k // patch_size for k in self.resolution]
self.inner_size = [k // patch_size for k in self.spatial_resolution]
pos_embed_size = [1] + self.inner_size + [hidden_dim]
self.pos_embed = nn.Parameter(0.02 * torch.randn(pos_embed_size))
self.pos_drop = nn.Dropout(p=drop_rate)
Expand Down
22 changes: 11 additions & 11 deletions the_well/benchmark/models/avit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from einops import rearrange
from timm.layers import DropPath

from the_well.data.datasets import WellMetadata
from the_well.benchmark.models.common import BaseModel


class hMLP_stem(nn.Module):
Expand Down Expand Up @@ -175,7 +175,7 @@ def forward(self, x):
return input + self.drop_path(self.gamma * x)


class AViT(nn.Module):
class AViT(BaseModel):
"""
Uses axial attention to predict forward dynamics. This simplified version
just stacks time in channels.
Expand All @@ -191,36 +191,36 @@ def __init__(
self,
dim_in: int,
dim_out: int,
dset_metadata: WellMetadata,
n_spatial_dims: int,
spatial_resolution: tuple[int, ...],
hidden_dim: int = 768,
num_heads: int = 12,
processor_blocks: int = 8,
drop_path: float = 0.0,
):
super().__init__()
super().__init__(n_spatial_dims, spatial_resolution)
# Normalization - not used in the well
self.drop_path = drop_path
self.dp = np.linspace(0, drop_path, processor_blocks)

self.resolution = tuple(dset_metadata.spatial_resolution)
# Patch size hardcoded at 16 in this implementation
self.patch_size = 16
# Embedding
pe_size = tuple(int(k / self.patch_size) for k in self.resolution) + (
pe_size = tuple(int(k / self.patch_size) for k in self.spatial_resolution) + (
hidden_dim,
)
self.absolute_pe = nn.Parameter(0.02 * torch.randn(*pe_size))
self.embed = hMLP_stem(
dim_in=dim_in,
hidden_dim=hidden_dim,
n_spatial_dims=dset_metadata.n_spatial_dims,
n_spatial_dims=self.n_spatial_dims,
)
self.blocks = nn.ModuleList(
[
AxialAttentionBlock(
hidden_dim=hidden_dim,
num_heads=num_heads,
n_spatial_dims=dset_metadata.n_spatial_dims,
n_spatial_dims=self.n_spatial_dims,
drop_path=self.dp[i],
)
for i in range(processor_blocks)
Expand All @@ -229,12 +229,12 @@ def __init__(
self.debed = hMLP_output(
hidden_dim=hidden_dim,
dim_out=dim_out,
n_spatial_dims=dset_metadata.n_spatial_dims,
n_spatial_dims=self.n_spatial_dims,
)

if dset_metadata.n_spatial_dims == 2:
if self.n_spatial_dims == 2:
self.embed_reshapes = ["b h w c -> b c h w", "b c h w -> b h w c"]
if dset_metadata.n_spatial_dims == 3:
if self.n_spatial_dims == 3:
self.embed_reshapes = ["b h w d c -> b c h w d", "b c h w d -> b h w d c"]

def forward(self, x):
Expand Down
12 changes: 12 additions & 0 deletions the_well/benchmark/models/common.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
from torch.nn.utils.parametrizations import spectral_norm


class BaseModel(nn.Module, PyTorchModelHubMixin):
    """Common base class for benchmark models.

    Stores the spatial metadata shared by all models and mixes in
    ``PyTorchModelHubMixin`` so subclasses gain Hugging Face Hub
    save/load support.

    Args:
        n_spatial_dims: Number of spatial dimensions of the data (e.g. 2 or 3).
        spatial_resolution: Size of each spatial dimension.
    """

    def __init__(
        self,
        n_spatial_dims: int,
        spatial_resolution: tuple[int, ...],
    ):
        super().__init__()
        self.n_spatial_dims = n_spatial_dims
        # Normalize to a tuple so list inputs (common from configs and tests)
        # match the annotated type and cannot be mutated by callers.
        self.spatial_resolution = tuple(spatial_resolution)


class NestedLinear(nn.Linear):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
super().__init__(
Expand Down
Loading