Skip to content

Commit b239407

Browse files
committed
Update the files as needed
1 parent b9ef291 commit b239407

File tree

5 files changed

+3
-396
lines changed

5 files changed

+3
-396
lines changed
Lines changed: 0 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +0,0 @@
1-
"""
2-
AI-based Data Assimilation Package
3-
4-
Implements AI-based data assimilation following the approach described in:
5-
"AI-Based Data Assimilation: Learning the Functional of Analysis Estimation" (arXiv:2406.00390)
6-
7-
This package provides neural networks that learn to produce optimal analysis states
8-
by minimizing the 3D-Var cost function in a self-supervised manner, without requiring
9-
ground-truth labels.
10-
"""

graph_weather/models/ai_assimilation/data.py

Lines changed: 1 addition & 84 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,3 @@
1-
"""
2-
Data Module for AI-based Data Assimilation
3-
4-
Handles the loading and preprocessing of first-guess states and observations
5-
for the AI-based assimilation approach.
6-
"""
7-
81
import warnings
92
from typing import Dict, Optional, Tuple
103

@@ -16,28 +9,12 @@
169

1710

1811
class AIAssimilationDataset(Dataset):
19-
"""
20-
Dataset for AI-based data assimilation.
21-
22-
Each sample contains a first-guess state and corresponding observations.
23-
The dataset is designed to work with self-supervised learning where
24-
no ground-truth analysis is required.
25-
"""
26-
2712
def __init__(
2813
self,
2914
first_guess_states: torch.Tensor,
3015
observations: torch.Tensor,
3116
observation_locations: Optional[torch.Tensor] = None,
3217
):
33-
"""
34-
Initialize the AI assimilation dataset.
35-
36-
Args:
37-
first_guess_states: First-guess states (background) [num_samples, state_size]
38-
observations: Observation values [num_samples, obs_size]
39-
observation_locations: Optional tensor indicating observation locations
40-
"""
4118
self.first_guess_states = first_guess_states
4219
self.observations = observations
4320
self.observation_locations = observation_locations
@@ -51,15 +28,6 @@ def __len__(self) -> int:
5128
return len(self.first_guess_states)
5229

5330
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
54-
"""
55-
Get a sample from the dataset.
56-
57-
Args:
58-
idx: Index of the sample
59-
60-
Returns:
61-
Dictionary containing first_guess, observations, and optionally locations
62-
"""
6331
sample = {
6432
"first_guess": self.first_guess_states[idx],
6533
"observations": self.observations[idx],
@@ -80,21 +48,6 @@ def generate_synthetic_assimilation_data(
8048
spatial_correlation: bool = False,
8149
grid_shape: Optional[Tuple[int, int]] = None,
8250
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
83-
"""
84-
Generate synthetic data for AI-based data assimilation experiments.
85-
86-
Args:
87-
num_samples: Number of samples to generate
88-
state_size: Size of the state vector
89-
obs_fraction: Fraction of state variables that have observations
90-
bg_error_std: Standard deviation of background (first-guess) errors
91-
obs_error_std: Standard deviation of observation errors
92-
spatial_correlation: Whether to add spatial correlation to the data
93-
grid_shape: Shape of spatial grid if applicable (h, w)
94-
95-
Returns:
96-
Tuple of (first_guess, observations, true_state) tensors
97-
"""
9851
# Generate a true state with possible spatial correlation
9952
if spatial_correlation and grid_shape is not None:
10053
h, w = grid_shape
@@ -158,12 +111,6 @@ def generate_synthetic_assimilation_data(
158111

159112

160113
class AIAssimilationDataModule:
161-
"""
162-
Data module for AI-based assimilation following PyTorch Lightning pattern.
163-
164-
Handles data splits and provides train/val/test loaders.
165-
"""
166-
167114
def __init__(
168115
self,
169116
num_samples: int = 1000,
@@ -178,22 +125,7 @@ def __init__(
178125
spatial_correlation: bool = False,
179126
grid_shape: Optional[Tuple[int, int]] = None,
180127
):
181-
"""
182-
Initialize the AI assimilation data module.
183-
184-
Args:
185-
num_samples: Number of total samples
186-
state_size: Size of state vector
187-
obs_fraction: Fraction of observed values
188-
bg_error_std: Background error standard deviation
189-
obs_error_std: Observation error standard deviation
190-
batch_size: Batch size for data loaders
191-
train_ratio: Fraction for training
192-
val_ratio: Fraction for validation
193-
test_ratio: Fraction for testing
194-
spatial_correlation: Whether to include spatial correlation
195-
grid_shape: Shape of spatial grid if applicable
196-
"""
128+
197129
self.num_samples = num_samples
198130
self.state_size = state_size
199131
self.obs_fraction = obs_fraction
@@ -215,12 +147,7 @@ def __init__(
215147
self.test_loader = None
216148

217149
def setup(self, stage: Optional[str] = None):
218-
"""
219-
Set up the datasets and data loaders.
220150

221-
Args:
222-
stage: Stage of training (fit, validate, test, predict)
223-
"""
224151
# Generate synthetic data
225152
first_guess, observations, true_state = generate_synthetic_assimilation_data(
226153
num_samples=self.num_samples,
@@ -266,17 +193,7 @@ def test_dataloader(self) -> DataLoader:
266193
def create_observation_operator(
267194
state_size: int, obs_size: int, obs_locations: Optional[np.ndarray] = None
268195
) -> torch.Tensor:
269-
"""
270-
Create an observation operator matrix H that maps state space to observation space.
271-
272-
Args:
273-
state_size: Size of the state vector
274-
obs_size: Size of the observation vector
275-
obs_locations: Specific locations of observations (indices in state vector)
276196

277-
Returns:
278-
Observation operator H [obs_size, state_size]
279-
"""
280197
if obs_locations is None:
281198
# Randomly select observation locations
282199
obs_indices = np.random.choice(state_size, size=obs_size, replace=False)

graph_weather/models/ai_assimilation/loss.py

Lines changed: 0 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -1,38 +1,16 @@
1-
"""
2-
Loss Module for AI-based Data Assimilation
3-
4-
Implements physics-based loss functions for training AI-based data assimilation models
5-
without requiring ground-truth labels, following the 3D-Var approach.
6-
"""
7-
81
from typing import Optional, Tuple
92

103
import torch
114
import torch.nn as nn
125

136

147
class ThreeDVarLoss(nn.Module):
15-
"""
16-
3D-Var loss function for AI-based data assimilation.
17-
18-
Implements the traditional 3D-Var cost function that balances fit to
19-
background (first-guess) state and observations.
20-
"""
21-
228
def __init__(
239
self,
2410
background_error_covariance: Optional[torch.Tensor] = None,
2511
observation_error_covariance: Optional[torch.Tensor] = None,
2612
observation_operator: Optional[torch.Tensor] = None,
2713
):
28-
"""
29-
Initialize the 3D-Var loss function.
30-
31-
Args:
32-
background_error_covariance: Background error covariance matrix B
33-
observation_error_covariance: Observation error covariance matrix R
34-
observation_operator: Observation operator matrix H
35-
"""
3614
super().__init__()
3715
self.background_error_covariance = background_error_covariance
3816
self.observation_error_covariance = observation_error_covariance
@@ -44,17 +22,6 @@ def forward(
4422
background: torch.Tensor,
4523
observations: torch.Tensor,
4624
) -> torch.Tensor:
47-
"""
48-
Compute the 3D-Var cost function.
49-
50-
Args:
51-
analysis: Analysis state produced by the AI model
52-
background: Background (first-guess) state
53-
observations: Observation values
54-
55-
Returns:
56-
Total 3D-Var cost as a scalar tensor
57-
"""
5825
# Background term: (x_a - x_b)^T B^{-1} (x_a - x_b)
5926
bg_diff = analysis - background
6027
if self.background_error_covariance is not None:
@@ -93,26 +60,12 @@ def forward(
9360

9461

9562
class PhysicsInformedLoss(nn.Module):
96-
"""
97-
Physics-informed loss combining 3D-Var with physical constraints.
98-
99-
Extends the basic 3D-Var loss with additional physics-based regularization terms.
100-
"""
101-
10263
def __init__(
10364
self,
10465
three_d_var_weight: float = 1.0,
10566
smoothness_weight: float = 0.1,
10667
conservation_weight: float = 0.05,
10768
):
108-
"""
109-
Initialize physics-informed loss.
110-
111-
Args:
112-
three_d_var_weight: Weight for 3D-Var term
113-
smoothness_weight: Weight for spatial smoothness regularization
114-
conservation_weight: Weight for conservation law enforcement
115-
"""
11669
super().__init__()
11770
self.three_d_var_weight = three_d_var_weight
11871
self.smoothness_weight = smoothness_weight
@@ -126,18 +79,6 @@ def forward(
12679
observations: torch.Tensor,
12780
grid_spacing: Optional[float] = None,
12881
) -> Tuple[torch.Tensor, dict]:
129-
"""
130-
Compute physics-informed loss with component breakdown.
131-
132-
Args:
133-
analysis: Analysis state from AI model
134-
background: Background state
135-
observations: Observation values
136-
grid_spacing: Spatial grid spacing for derivative calculations
137-
138-
Returns:
139-
Tuple of (total_loss, loss_components_dict)
140-
"""
14182
# Base 3D-Var loss
14283
three_d_var_loss = self.base_loss(analysis, background, observations)
14384

0 commit comments

Comments
 (0)