Commit 55ecd35
Add FiLM module (PoC) and unit test (#184)

* Add FiLM generator + applier (MetNet-style one-hot) and a basic unit test (PoC)
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* film: add Google-style docstrings (module) and delete comment msg
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* refactor: move FiLM module under models/layers and remove re-export

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent df26c79 commit 55ecd35
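For context: FiLM (feature-wise linear modulation, Perez et al. 2018) conditions a network by scaling and shifting intermediate features elementwise, FiLM(x) = gamma * x + beta, where gamma and beta are predicted from a conditioning input. Here, as in MetNet, that conditioning input is a one-hot encoding of the forecast lead time, so the same backbone can be modulated differently for each lead time.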

File tree

2 files changed: +97 −0 lines

2 files changed

+97
-0
lines changed
graph_weather/models/layers/film.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
import torch
import torch.nn as nn


class FiLMGenerator(nn.Module):
    """
    Generates FiLM parameters (gamma and beta) from a lead-time index.

    A one-hot vector for the given lead time is expanded to the batch size
    and passed through a small MLP to produce FiLM modulation parameters.

    Args:
        num_lead_times (int): Number of possible lead-time categories.
        hidden_dim (int): Hidden size for the internal MLP.
        feature_dim (int): Output dimensionality of gamma and beta.
    """

    def __init__(self, num_lead_times: int, hidden_dim: int, feature_dim: int):
        super().__init__()
        self.num_lead_times = num_lead_times
        self.feature_dim = feature_dim
        self.network = nn.Sequential(
            nn.Linear(num_lead_times, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * feature_dim),
        )

    def forward(self, batch_size: int, lead_time: int, device=None):
        """
        Compute FiLM gamma and beta parameters.

        Args:
            batch_size (int): Number of samples to generate parameters for.
            lead_time (int): Lead-time index used to construct the one-hot input.
            device (optional): Device to place tensors on. Defaults to CPU.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                gamma: Tensor of shape (batch_size, feature_dim).
                beta: Tensor of shape (batch_size, feature_dim).
        """
        # Build a one-hot encoding of the lead time, repeated for every sample.
        one_hot = torch.zeros(batch_size, self.num_lead_times, device=device)
        one_hot[:, lead_time] = 1.0
        # The MLP emits gamma and beta concatenated along the feature axis.
        gamma_beta = self.network(one_hot)
        gamma = gamma_beta[:, : self.feature_dim]
        beta = gamma_beta[:, self.feature_dim :]
        return gamma, beta


class FiLMApplier(nn.Module):
    """
    Applies FiLM modulation to an input tensor.

    Gamma and beta are broadcast to match the dimensionality of the input,
    and the FiLM operation is applied elementwise.
    """

    def forward(self, x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
        """
        Apply FiLM conditioning.

        Args:
            x (torch.Tensor): Input tensor of shape (B, C, ...).
            gamma (torch.Tensor): Scaling parameters of shape (B, C).
            beta (torch.Tensor): Bias parameters of shape (B, C).

        Returns:
            torch.Tensor: Output tensor after FiLM modulation, same shape as `x`.
        """
        # Unsqueeze trailing dims so (B, C) broadcasts against (B, C, ...).
        while gamma.ndim < x.ndim:
            gamma = gamma.unsqueeze(-1)
            beta = beta.unsqueeze(-1)
        return x * gamma + beta
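A minimal sketch of how the two modules might be wired together in practice. The ConditionedBlock wrapper below, its channel count, and the conv layer are illustrative assumptions, not part of this commit:

import torch
import torch.nn as nn

from graph_weather.models.layers.film import FiLMApplier, FiLMGenerator


class ConditionedBlock(nn.Module):
    """Hypothetical block: conv features modulated by lead-time FiLM parameters."""

    def __init__(self, channels: int = 16, num_lead_times: int = 10):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.film_gen = FiLMGenerator(num_lead_times, hidden_dim=8, feature_dim=channels)
        self.film = FiLMApplier()

    def forward(self, x: torch.Tensor, lead_time: int) -> torch.Tensor:
        h = self.conv(x)
        # One (gamma, beta) pair per sample, determined solely by the lead-time index.
        gamma, beta = self.film_gen(x.shape[0], lead_time, device=x.device)
        return self.film(h, gamma, beta)


block = ConditionedBlock()
out = block(torch.randn(4, 16, 8, 8), lead_time=3)
assert out.shape == (4, 16, 8, 8)

Note that because gamma and beta depend only on the lead-time index, every sample in the batch receives the same modulation; per-sample conditioning would require passing per-sample lead times instead.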

tests/test_film.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
import torch
from graph_weather.models.layers.film import FiLMGenerator, FiLMApplier


def test_film_shapes():
    batch = 4
    feature_dim = 16
    num_steps = 10
    hidden_dim = 8
    lead_time = 3

    gen = FiLMGenerator(num_steps, hidden_dim, feature_dim)
    apply = FiLMApplier()

    gamma, beta = gen(batch, lead_time, device="cpu")

    assert gamma.shape == (batch, feature_dim)
    assert beta.shape == (batch, feature_dim)

    x = torch.randn(batch, feature_dim, 8, 8)
    out = apply(x, gamma, beta)
    assert out.shape == x.shape
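Assuming the repository's usual pytest setup, the new test can be run on its own with:

pytest tests/test_film.py -q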
