1919def test_model_creation_and_forward_pass ():
2020 state_size = 20
2121 net = model .AIAssimilationNet (state_size = state_size )
22-
22+
2323 # Create test inputs
2424 first_guess = torch .randn (3 , state_size )
2525 observations = torch .randn (3 , state_size )
26-
26+
2727 # Forward pass
2828 analysis = net (first_guess , observations )
29-
29+
3030 # Verify output shape and validity
3131 assert analysis .shape == (3 , state_size ), "Output shape should match input batch and state size"
3232 assert not torch .isnan (analysis ).any ().item (), "Output should not contain NaN values"
@@ -35,17 +35,17 @@ def test_model_creation_and_forward_pass():
3535
3636def test_3dvar_loss_function ():
3737 loss_fn = loss .ThreeDVarLoss ()
38-
38+
3939 # Create test tensors
4040 batch_size = 2
4141 state_size = 15
4242 analysis = torch .randn (batch_size , state_size )
4343 first_guess = torch .randn (batch_size , state_size )
4444 observations = torch .randn (batch_size , state_size )
45-
45+
4646 # Calculate loss
4747 total_loss = loss_fn (analysis , first_guess , observations )
48-
48+
4949 # Verify loss properties
5050 assert total_loss .dim () == 0 , "Loss should be a scalar tensor"
5151 assert total_loss >= 0 , "Loss should be non-negative"
@@ -57,51 +57,50 @@ def test_dataset_creation():
5757 # Create test data directly
5858 batch_size = 8
5959 state_size = 12
60-
60+
6161 first_guess = torch .randn (batch_size , state_size )
6262 observations = torch .randn (batch_size , state_size )
63-
63+
6464 # Create dataset
6565 dataset = data .AIAssimilationDataset (first_guess , observations )
66-
66+
6767 # Verify dataset properties
6868 assert len (dataset ) == batch_size , "Dataset length should match number of samples"
69-
69+
7070 # Get a sample
7171 sample = dataset [0 ]
72-
72+
7373 # Verify sample structure
7474 assert isinstance (sample , dict ), "Sample should be a dictionary"
75- assert ' first_guess' in sample , "Sample should contain 'first_guess'"
76- assert ' observations' in sample , "Sample should contain 'observations'"
77-
75+ assert " first_guess" in sample , "Sample should contain 'first_guess'"
76+ assert " observations" in sample , "Sample should contain 'observations'"
77+
7878 # Verify sample shapes
79- assert sample ['first_guess' ].shape == (state_size ,), "First guess in sample should have correct shape"
80- assert sample ['observations' ].shape == (state_size ,), "Observations in sample should have correct shape"
79+ assert sample ["first_guess" ].shape == (
80+ state_size ,
81+ ), "First guess in sample should have correct shape"
82+ assert sample ["observations" ].shape == (
83+ state_size ,
84+ ), "Observations in sample should have correct shape"
8185
8286
8387def test_trainer_functionality ():
8488 state_size = 10
85-
89+
8690 # Create model and loss function
8791 net = model .AIAssimilationNet (state_size = state_size )
8892 loss_fn = loss .ThreeDVarLoss ()
89-
93+
9094 # Create trainer
91- trainer = training .AIBasedAssimilationTrainer (
92- model = net ,
93- loss_fn = loss_fn ,
94- lr = 1e-3 ,
95- device = 'cpu'
96- )
97-
95+ trainer = training .AIBasedAssimilationTrainer (model = net , loss_fn = loss_fn , lr = 1e-3 , device = "cpu" )
96+
9897 # Create test batch
9998 batch_fg = torch .randn (2 , state_size )
10099 batch_obs = torch .randn (2 , state_size )
101-
100+
102101 # Run training step
103102 train_loss = trainer .train_step (batch_fg , batch_obs )
104-
103+
105104 # Verify training step result
106105 assert isinstance (train_loss , float ), "Training loss should be a float"
107106 assert not torch .isnan (torch .tensor (train_loss )).any ().item (), "Training loss should not be NaN"
0 commit comments