1-
21import sys
32import os
43import torch
5- import unittest
4+ import pytest
65
76# Add the AI assimilation module directory to the path
87ai_assimilation_path = os .path .join (os .path .dirname (__file__ ), '..' , 'graph_weather' , 'models' , 'ai_assimilation' )
1514import training
1615
1716
18- class TestAIAssimilation (unittest .TestCase ):
19- """Test class for AI-based data assimilation functionality."""
20-
21- def test_model_creation_and_forward_pass (self ):
22- """Test that the AI assimilation model can be created and performs forward pass."""
23- state_size = 20
24- net = model .AIAssimilationNet (state_size = state_size )
25-
26- # Create test inputs
27- first_guess = torch .randn (3 , state_size )
28- observations = torch .randn (3 , state_size )
29-
30- # Forward pass
31- analysis = net (first_guess , observations )
32-
33- # Verify output shape and validity
34- self .assertEqual (analysis .shape , (3 , state_size ), "Output shape should match input batch and state size" )
35- self .assertFalse (torch .isnan (analysis ).any ().item (), "Output should not contain NaN values" )
36- self .assertFalse (torch .isinf (analysis ).any ().item (), "Output should not contain Inf values" )
37-
38- def test_3dvar_loss_function (self ):
39- """Test that the 3D-Var loss function works correctly."""
40- loss_fn = loss .ThreeDVarLoss ()
41-
42- # Create test tensors
43- batch_size = 2
44- state_size = 15
45- analysis = torch .randn (batch_size , state_size )
46- first_guess = torch .randn (batch_size , state_size )
47- observations = torch .randn (batch_size , state_size )
48-
49- # Calculate loss
50- total_loss = loss_fn (analysis , first_guess , observations )
51-
52- # Verify loss properties
53- self .assertEqual (total_loss .dim (), 0 , "Loss should be a scalar tensor" )
54- self .assertGreaterEqual (total_loss .item (), 0 , "Loss should be non-negative" )
55- self .assertFalse (torch .isnan (total_loss ).any ().item (), "Loss should not contain NaN values" )
56- self .assertFalse (torch .isinf (total_loss ).any ().item (), "Loss should not contain Inf values" )
57-
58- def test_synthetic_data_generation (self ):
59- """Test that synthetic data generation works correctly."""
60- # Generate synthetic data
61- num_samples = 10
62- state_size = 25
63- first_guess , observations , true_state = data .generate_synthetic_assimilation_data (
64- num_samples = num_samples ,
65- state_size = state_size ,
66- obs_fraction = 0.6 ,
67- bg_error_std = 0.3 ,
68- obs_error_std = 0.2
69- )
70-
71- # Verify data shapes
72- self .assertEqual (first_guess .shape , (num_samples , state_size ), "First guess should have correct shape" )
73- self .assertEqual (observations .shape , (num_samples , state_size ), "Observations should have correct shape" )
74- self .assertEqual (true_state .shape , (num_samples , state_size ), "True state should have correct shape" )
75-
76- # Verify data validity
77- self .assertFalse (torch .isnan (first_guess ).any ().item (), "First guess should not contain NaN values" )
78- self .assertFalse (torch .isinf (first_guess ).any ().item (), "First guess should not contain Inf values" )
79- self .assertFalse (torch .isnan (observations ).any ().item (), "Observations should not contain NaN values" )
80- self .assertFalse (torch .isinf (observations ).any ().item (), "Observations should not contain Inf values" )
81-
82- def test_dataset_creation (self ):
83- """Test that the AI assimilation dataset works correctly."""
84- # Generate test data
85- num_samples = 8
86- state_size = 12
87- first_guess , observations , _ = data .generate_synthetic_assimilation_data (
88- num_samples = num_samples ,
89- state_size = state_size
90- )
91-
92- # Create dataset
93- dataset = data .AIAssimilationDataset (first_guess , observations )
94-
95- # Verify dataset properties
96- self .assertEqual (len (dataset ), num_samples , "Dataset length should match number of samples" )
97-
98- # Get a sample
99- sample = dataset [0 ]
100-
101- # Verify sample structure
102- self .assertIsInstance (sample , dict , "Sample should be a dictionary" )
103- self .assertIn ('first_guess' , sample , "Sample should contain 'first_guess'" )
104- self .assertIn ('observations' , sample , "Sample should contain 'observations'" )
105-
106- # Verify sample shapes
107- self .assertEqual (sample ['first_guess' ].shape , (state_size ,), "First guess in sample should have correct shape" )
108- self .assertEqual (sample ['observations' ].shape , (state_size ,), "Observations in sample should have correct shape" )
109-
110- def test_trainer_functionality (self ):
111- """Test that the AI assimilation trainer works correctly."""
112- state_size = 10
113-
114- # Create model and loss function
115- net = model .AIAssimilationNet (state_size = state_size )
116- loss_fn = loss .ThreeDVarLoss ()
117-
118- # Create trainer
119- trainer = training .AIBasedAssimilationTrainer (
120- model = net ,
121- loss_fn = loss_fn ,
122- lr = 1e-3 ,
123- device = 'cpu'
124- )
125-
126- # Create test batch
127- batch_fg = torch .randn (2 , state_size )
128- batch_obs = torch .randn (2 , state_size )
129-
130- # Run training step
131- train_loss = trainer .train_step (batch_fg , batch_obs )
132-
133- # Verify training step result
134- self .assertIsInstance (train_loss , float , "Training loss should be a float" )
135- self .assertFalse (torch .isnan (torch .tensor (train_loss )).any ().item (), "Training loss should not be NaN" )
136- self .assertFalse (torch .isinf (torch .tensor (train_loss )).any ().item (), "Training loss should not be Inf" )
17+ def test_model_creation_and_forward_pass ():
18+ state_size = 20
19+ net = model .AIAssimilationNet (state_size = state_size )
20+
21+ # Create test inputs
22+ first_guess = torch .randn (3 , state_size )
23+ observations = torch .randn (3 , state_size )
24+
25+ # Forward pass
26+ analysis = net (first_guess , observations )
27+
28+ # Verify output shape and validity
29+ assert analysis .shape == (3 , state_size ), "Output shape should match input batch and state size"
30+ assert not torch .isnan (analysis ).any ().item (), "Output should not contain NaN values"
31+ assert not torch .isinf (analysis ).any ().item (), "Output should not contain Inf values"
32+
33+
34+ def test_3dvar_loss_function ():
35+ """Test that the 3D-Var loss function works correctly."""
36+ loss_fn = loss .ThreeDVarLoss ()
37+
38+ # Create test tensors
39+ batch_size = 2
40+ state_size = 15
41+ analysis = torch .randn (batch_size , state_size )
42+ first_guess = torch .randn (batch_size , state_size )
43+ observations = torch .randn (batch_size , state_size )
44+
45+ # Calculate loss
46+ total_loss = loss_fn (analysis , first_guess , observations )
47+
48+ # Verify loss properties
49+ assert total_loss .dim () == 0 , "Loss should be a scalar tensor"
50+ assert total_loss >= 0 , "Loss should be non-negative"
51+ assert not torch .isnan (total_loss ).any ().item (), "Loss should not contain NaN values"
52+ assert not torch .isinf (total_loss ).any ().item (), "Loss should not contain Inf values"
13753
13854
139- def run_tests ():
140- """Run all AI assimilation tests."""
141- print ("Running AI Assimilation Tests...\n " )
142-
143- # Create test suite
144- suite = unittest .TestLoader ().loadTestsFromTestCase (TestAIAssimilation )
145-
146- # Run tests
147- runner = unittest .TextTestRunner (verbosity = 2 )
148- result = runner .run (suite )
149-
150- # Print summary
151- print ("\n " + "=" * 50 )
152- if result .wasSuccessful ():
153- print ("ALL TESTS PASSED! AI assimilation module is working correctly." )
154- else :
155- print (" SOME TESTS FAILED! Please check the AI assimilation module." )
156- print (f"Failures: { len (result .failures )} , Errors: { len (result .errors )} " )
157-
158- return result .wasSuccessful ()
55+ def test_dataset_creation ():
56+ # Create test data directly
57+ batch_size = 8
58+ state_size = 12
59+
60+ first_guess = torch .randn (batch_size , state_size )
61+ observations = torch .randn (batch_size , state_size )
62+
63+ # Create dataset
64+ dataset = data .AIAssimilationDataset (first_guess , observations )
65+
66+ # Verify dataset properties
67+ assert len (dataset ) == batch_size , "Dataset length should match number of samples"
68+
69+ # Get a sample
70+ sample = dataset [0 ]
71+
72+ # Verify sample structure
73+ assert isinstance (sample , dict ), "Sample should be a dictionary"
74+ assert 'first_guess' in sample , "Sample should contain 'first_guess'"
75+ assert 'observations' in sample , "Sample should contain 'observations'"
76+
77+ # Verify sample shapes
78+ assert sample ['first_guess' ].shape == (state_size ,), "First guess in sample should have correct shape"
79+ assert sample ['observations' ].shape == (state_size ,), "Observations in sample should have correct shape"
15980
16081
161- if __name__ == "__main__" :
162- success = run_tests ()
163- if success :
164- print ("\n AI assimilation module verification completed successfully!" )
165- print ("\n Components tested:" )
166- print ("- AIAssimilationNet (model): Passed" )
167- print ("- ThreeDVarLoss (loss function): Passed" )
168- print ("- AIAssimilationDataset (data handling): Passed" )
169- print ("- AIBasedAssimilationTrainer (training): Passed" )
170- print ("- Synthetic data generation: Passed" )
171- else :
172- print ("\n AI assimilation module has issues that need to be addressed." )
82+ def test_trainer_functionality ():
83+ state_size = 10
84+
85+ # Create model and loss function
86+ net = model .AIAssimilationNet (state_size = state_size )
87+ loss_fn = loss .ThreeDVarLoss ()
88+
89+ # Create trainer
90+ trainer = training .AIBasedAssimilationTrainer (
91+ model = net ,
92+ loss_fn = loss_fn ,
93+ lr = 1e-3 ,
94+ device = 'cpu'
95+ )
96+
97+ # Create test batch
98+ batch_fg = torch .randn (2 , state_size )
99+ batch_obs = torch .randn (2 , state_size )
100+
101+ # Run training step
102+ train_loss = trainer .train_step (batch_fg , batch_obs )
103+
104+ # Verify training step result
105+ assert isinstance (train_loss , float ), "Training loss should be a float"
106+ assert not torch .isnan (torch .tensor (train_loss )).any ().item (), "Training loss should not be NaN"
107+ assert not torch .isinf (torch .tensor (train_loss )).any ().item (), "Training loss should not be Inf"
0 commit comments