
Commit dbc27ea

feat: Add Stochastic Decomposition Layer and Fix Dependencies (#188)
* feat: Add Stochastic Decomposition Layer and Fix Dependencies
* refactor: Move test_stochastic_decomposition.py per review
1 parent d54683d commit dbc27ea


4 files changed, +161 -1 lines changed


graph_weather/data/nnja_ai.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
 from torch.utils.data import Dataset

 try:
-    from nnja import DataCatalog
+    from nnja_ai import DataCatalog
 except ImportError:
     raise ImportError("NNJA-AI library not installed. Install with: " "`pip install nnja-ai`")

graph_weather/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -12,3 +12,4 @@
 from .layers.decoder import Decoder
 from .layers.encoder import Encoder
 from .layers.processor import Processor
+from .layers.stochastic_decomposition import StochasticDecompositionLayer
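Since graph_weather/models/__init__.py now re-exports the class, both import paths resolve to the same object. A quick check, as a sketch assuming this branch is installed in the current environment:

from graph_weather.models import StochasticDecompositionLayer
from graph_weather.models.layers.stochastic_decomposition import (
    StochasticDecompositionLayer as _SDL,
)

# Both import paths refer to the same class object thanks to the re-export above.
assert StochasticDecompositionLayer is _SDL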
graph_weather/models/layers/stochastic_decomposition.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
"""Stochastic Decomposition Layer to inject controllable noise into feature maps.

In the original paper, the Stochastic Decomposition Layer (SDL) is described as follows:

The SDL decomposes the intermediate feature map into a deterministic component (the input)
and a stochastic component (the noise). This decomposition allows the model to separate
the signal processing from the uncertainty quantification.

The stochastic component is generated by modulating random Gaussian noise with a learned
style vector derived from a latent control variable. This architecture enables the
ensemble generation process to be explicitly controlled by the latent variable, rather
than relying on implicit randomness.

The SDL operation is defined as:

    Output = x + (alpha * Style(z) * epsilon)

where x is the deterministic input, z is the latent control vector, and alpha is a
learnable channel-wise scaling factor that determines the magnitude of the stochastic
perturbation.
"""

import torch
import torch.nn as nn


class StochasticDecompositionLayer(nn.Module):
    """Stochastic Decomposition Layer for controllable probabilistic outputs."""

    def __init__(self, input_dim: int, latent_dim: int):
        """Initialize the Stochastic Decomposition Layer.

        Args:
            input_dim: Number of channels in the input feature map
            latent_dim: Dimension of the latent control vector
        """
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Channel-wise scaling factor, initialized to zero so the layer starts as an identity.
        self.alpha = nn.Parameter(torch.zeros(1, input_dim, 1))

        # Maps the latent control vector z to a per-channel style vector.
        self.style_net = nn.Linear(latent_dim, input_dim)

    def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Apply stochastic decomposition to input features.

        Args:
            x: Input features [Batch, Channels, *Spatial]
            z: Latent control vector [Batch, Latent_Dim]

        Returns:
            Output features with injected stochasticity
        """
        if x.size(1) != self.input_dim:
            raise ValueError(f"Expected {self.input_dim} channels, got {x.size(1)}")

        epsilon = torch.randn_like(x)

        style = self.style_net(z)  # [B, C]

        # Broadcast the style vector over the spatial/temporal dimensions of x.
        spatial_dims = x.dim() - 2
        for _ in range(spatial_dims):
            style = style.unsqueeze(-1)

        # Broadcast alpha to match the dimensionality of x.
        alpha_broadcast = self.alpha
        while alpha_broadcast.dim() < x.dim():
            alpha_broadcast = alpha_broadcast.unsqueeze(-1)

        return x + (alpha_broadcast * style * epsilon)
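A minimal sketch of how the layer might drive ensemble generation, based only on the forward pass above; the shapes, the alpha fill value, and the four-member loop are illustrative and not part of the commit:

import torch

from graph_weather.models import StochasticDecompositionLayer

layer = StochasticDecompositionLayer(input_dim=32, latent_dim=16)

# alpha starts at zero (identity behaviour); set it manually here only so the
# stochastic term is visible in this sketch. In practice it is learned.
with torch.no_grad():
    layer.alpha.fill_(0.1)

x = torch.randn(1, 32, 64)   # deterministic features [Batch, Channels, Nodes]
members = []
for _ in range(4):           # one ensemble member per sampled latent control vector
    z = torch.randn(1, 16)
    members.append(layer(x, z))

# [4, 1, 32, 64]; members differ through the stochastic term alpha * Style(z) * epsilon
ensemble = torch.stack(members)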
test_stochastic_decomposition.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
"""Tests for the Stochastic Decomposition Layer."""

import pytest
import torch

from graph_weather.models.layers.stochastic_decomposition import (
    StochasticDecompositionLayer,
)


@pytest.mark.parametrize(
    "shape",
    [
        (2, 32, 10),
        (2, 32, 16, 16),
        (2, 32, 8, 16, 16),
    ],
)
def test_sdl_shapes(shape):
    """Ensure the SDL handles arbitrary spatial/temporal dimensions via broadcasting."""
    batch, channels = shape[0], shape[1]
    latent_dim = 16

    x = torch.randn(*shape)
    z = torch.randn(batch, latent_dim)

    model = StochasticDecompositionLayer(input_dim=channels, latent_dim=latent_dim)
    out = model(x, z)

    assert out.shape == shape
    assert not torch.isnan(out).any()


def test_initialization_is_deterministic():
    """Alpha initialized to zero should make the layer an identity function initially."""
    x = torch.randn(2, 64, 32, 32)
    z = torch.randn(2, 16)

    model = StochasticDecompositionLayer(input_dim=64, latent_dim=16)

    assert torch.allclose(model.alpha, torch.zeros_like(model.alpha))

    out = model(x, z)
    assert torch.allclose(out, x, atol=1e-6)


def test_reproducibility():
    """Fixed seed + fixed latent = fixed output."""
    x = torch.randn(2, 16, 10)
    z = torch.randn(2, 8)

    model = StochasticDecompositionLayer(16, 8)

    with torch.no_grad():
        model.alpha.fill_(0.5)

    torch.manual_seed(42)
    out1 = model(x, z)

    torch.manual_seed(42)
    out2 = model(x, z)

    assert torch.equal(out1, out2)


def test_gradient_flow():
    """Test that gradients flow correctly through the layer."""
    x = torch.randn(2, 16, 10, requires_grad=True)
    z = torch.randn(2, 8, requires_grad=True)

    model = StochasticDecompositionLayer(16, 8)
    with torch.no_grad():
        model.alpha.fill_(0.1)

    out = model(x, z)
    loss = out.sum()
    loss.backward()

    assert model.style_net.weight.grad is not None
    assert model.alpha.grad is not None
    assert x.grad is not None


def test_channel_mismatch_error():
    """Test that a channel mismatch raises ValueError."""
    x = torch.randn(2, 32, 10)
    z = torch.randn(2, 8)
    model = StochasticDecompositionLayer(input_dim=16, latent_dim=8)

    with pytest.raises(ValueError):
        model(x, z)
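One property the suite above does not check directly is that the output actually responds to the latent control vector; a possible additional test in the same style (hypothetical, not part of this commit):

import torch

from graph_weather.models.layers.stochastic_decomposition import StochasticDecompositionLayer


def test_output_depends_on_latent():
    """Different latent vectors should give different outputs for the same noise draw."""
    model = StochasticDecompositionLayer(16, 8)
    with torch.no_grad():
        model.alpha.fill_(0.5)

    x = torch.randn(2, 16, 10)
    z1 = torch.randn(2, 8)
    z2 = torch.randn(2, 8)

    # Reseed before each call so epsilon inside forward() is identical; any
    # difference between the outputs then comes from Style(z) alone.
    torch.manual_seed(0)
    out1 = model(x, z1)
    torch.manual_seed(0)
    out2 = model(x, z2)

    assert not torch.allclose(out1, out2)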
