1- """
2- Data Module for AI-based Data Assimilation
3-
4- Handles the loading and preprocessing of first-guess states and observations
5- for the AI-based assimilation approach.
6- """
7-
81import warnings
92from typing import Dict , Optional , Tuple
103
169
1710
1811class AIAssimilationDataset (Dataset ):
19- """
20- Dataset for AI-based data assimilation.
21-
22- Each sample contains a first-guess state and corresponding observations.
23- The dataset is designed to work with self-supervised learning where
24- no ground-truth analysis is required.
25- """
26-
2712 def __init__ (
2813 self ,
2914 first_guess_states : torch .Tensor ,
3015 observations : torch .Tensor ,
3116 observation_locations : Optional [torch .Tensor ] = None ,
3217 ):
33- """
34- Initialize the AI assimilation dataset.
35-
36- Args:
37- first_guess_states: First-guess states (background) [num_samples, state_size]
38- observations: Observation values [num_samples, obs_size]
39- observation_locations: Optional tensor indicating observation locations
40- """
4118 self .first_guess_states = first_guess_states
4219 self .observations = observations
4320 self .observation_locations = observation_locations
@@ -51,15 +28,6 @@ def __len__(self) -> int:
5128 return len (self .first_guess_states )
5229
5330 def __getitem__ (self , idx : int ) -> Dict [str , torch .Tensor ]:
54- """
55- Get a sample from the dataset.
56-
57- Args:
58- idx: Index of the sample
59-
60- Returns:
61- Dictionary containing first_guess, observations, and optionally locations
62- """
6331 sample = {
6432 "first_guess" : self .first_guess_states [idx ],
6533 "observations" : self .observations [idx ],
@@ -80,21 +48,6 @@ def generate_synthetic_assimilation_data(
8048 spatial_correlation : bool = False ,
8149 grid_shape : Optional [Tuple [int , int ]] = None ,
8250) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
83- """
84- Generate synthetic data for AI-based data assimilation experiments.
85-
86- Args:
87- num_samples: Number of samples to generate
88- state_size: Size of the state vector
89- obs_fraction: Fraction of state variables that have observations
90- bg_error_std: Standard deviation of background (first-guess) errors
91- obs_error_std: Standard deviation of observation errors
92- spatial_correlation: Whether to add spatial correlation to the data
93- grid_shape: Shape of spatial grid if applicable (h, w)
94-
95- Returns:
96- Tuple of (first_guess, observations, true_state) tensors
97- """
9851 # Generate a true state with possible spatial correlation
9952 if spatial_correlation and grid_shape is not None :
10053 h , w = grid_shape
@@ -158,12 +111,6 @@ def generate_synthetic_assimilation_data(
158111
159112
160113class AIAssimilationDataModule :
161- """
162- Data module for AI-based assimilation following PyTorch Lightning pattern.
163-
164- Handles data splits and provides train/val/test loaders.
165- """
166-
167114 def __init__ (
168115 self ,
169116 num_samples : int = 1000 ,
@@ -178,22 +125,7 @@ def __init__(
178125 spatial_correlation : bool = False ,
179126 grid_shape : Optional [Tuple [int , int ]] = None ,
180127 ):
181- """
182- Initialize the AI assimilation data module.
183-
184- Args:
185- num_samples: Number of total samples
186- state_size: Size of state vector
187- obs_fraction: Fraction of observed values
188- bg_error_std: Background error standard deviation
189- obs_error_std: Observation error standard deviation
190- batch_size: Batch size for data loaders
191- train_ratio: Fraction for training
192- val_ratio: Fraction for validation
193- test_ratio: Fraction for testing
194- spatial_correlation: Whether to include spatial correlation
195- grid_shape: Shape of spatial grid if applicable
196- """
128+
197129 self .num_samples = num_samples
198130 self .state_size = state_size
199131 self .obs_fraction = obs_fraction
@@ -215,12 +147,7 @@ def __init__(
215147 self .test_loader = None
216148
217149 def setup (self , stage : Optional [str ] = None ):
218- """
219- Setup the datasets and data loaders.
220150
221- Args:
222- stage: Stage of training (fit, validate, test, predict)
223- """
224151 # Generate synthetic data
225152 first_guess , observations , true_state = generate_synthetic_assimilation_data (
226153 num_samples = self .num_samples ,
@@ -266,17 +193,7 @@ def test_dataloader(self) -> DataLoader:
266193def create_observation_operator (
267194 state_size : int , obs_size : int , obs_locations : Optional [np .ndarray ] = None
268195) -> torch .Tensor :
269- """
270- Create an observation operator matrix H that maps state space to observation space.
271-
272- Args:
273- state_size: Size of the state vector
274- obs_size: Size of the observation vector
275- obs_locations: Specific locations of observations (indices in state vector)
276196
277- Returns:
278- Observation operator H [obs_size, state_size]
279- """
280197 if obs_locations is None :
281198 # Randomly select observation locations
282199 obs_indices = np .random .choice (state_size , size = obs_size , replace = False )
0 commit comments