Skip to content

Commit e7ff8b9

Browse files
committed
test_ai_assimilation
1 parent 3ade764 commit e7ff8b9

File tree

1 file changed

+88
-153
lines changed

1 file changed

+88
-153
lines changed

tests/test_ai_assimilation.py

Lines changed: 88 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
21
import sys
32
import os
43
import torch
5-
import unittest
4+
import pytest
65

76
# Add the AI assimilation module directory to the path
87
ai_assimilation_path = os.path.join(os.path.dirname(__file__), '..', 'graph_weather', 'models', 'ai_assimilation')
@@ -15,158 +14,94 @@
1514
import training
1615

1716

18-
class TestAIAssimilation(unittest.TestCase):
19-
"""Test class for AI-based data assimilation functionality."""
20-
21-
def test_model_creation_and_forward_pass(self):
22-
"""Test that the AI assimilation model can be created and performs forward pass."""
23-
state_size = 20
24-
net = model.AIAssimilationNet(state_size=state_size)
25-
26-
# Create test inputs
27-
first_guess = torch.randn(3, state_size)
28-
observations = torch.randn(3, state_size)
29-
30-
# Forward pass
31-
analysis = net(first_guess, observations)
32-
33-
# Verify output shape and validity
34-
self.assertEqual(analysis.shape, (3, state_size), "Output shape should match input batch and state size")
35-
self.assertFalse(torch.isnan(analysis).any().item(), "Output should not contain NaN values")
36-
self.assertFalse(torch.isinf(analysis).any().item(), "Output should not contain Inf values")
37-
38-
def test_3dvar_loss_function(self):
39-
"""Test that the 3D-Var loss function works correctly."""
40-
loss_fn = loss.ThreeDVarLoss()
41-
42-
# Create test tensors
43-
batch_size = 2
44-
state_size = 15
45-
analysis = torch.randn(batch_size, state_size)
46-
first_guess = torch.randn(batch_size, state_size)
47-
observations = torch.randn(batch_size, state_size)
48-
49-
# Calculate loss
50-
total_loss = loss_fn(analysis, first_guess, observations)
51-
52-
# Verify loss properties
53-
self.assertEqual(total_loss.dim(), 0, "Loss should be a scalar tensor")
54-
self.assertGreaterEqual(total_loss.item(), 0, "Loss should be non-negative")
55-
self.assertFalse(torch.isnan(total_loss).any().item(), "Loss should not contain NaN values")
56-
self.assertFalse(torch.isinf(total_loss).any().item(), "Loss should not contain Inf values")
57-
58-
def test_synthetic_data_generation(self):
59-
"""Test that synthetic data generation works correctly."""
60-
# Generate synthetic data
61-
num_samples = 10
62-
state_size = 25
63-
first_guess, observations, true_state = data.generate_synthetic_assimilation_data(
64-
num_samples=num_samples,
65-
state_size=state_size,
66-
obs_fraction=0.6,
67-
bg_error_std=0.3,
68-
obs_error_std=0.2
69-
)
70-
71-
# Verify data shapes
72-
self.assertEqual(first_guess.shape, (num_samples, state_size), "First guess should have correct shape")
73-
self.assertEqual(observations.shape, (num_samples, state_size), "Observations should have correct shape")
74-
self.assertEqual(true_state.shape, (num_samples, state_size), "True state should have correct shape")
75-
76-
# Verify data validity
77-
self.assertFalse(torch.isnan(first_guess).any().item(), "First guess should not contain NaN values")
78-
self.assertFalse(torch.isinf(first_guess).any().item(), "First guess should not contain Inf values")
79-
self.assertFalse(torch.isnan(observations).any().item(), "Observations should not contain NaN values")
80-
self.assertFalse(torch.isinf(observations).any().item(), "Observations should not contain Inf values")
81-
82-
def test_dataset_creation(self):
83-
"""Test that the AI assimilation dataset works correctly."""
84-
# Generate test data
85-
num_samples = 8
86-
state_size = 12
87-
first_guess, observations, _ = data.generate_synthetic_assimilation_data(
88-
num_samples=num_samples,
89-
state_size=state_size
90-
)
91-
92-
# Create dataset
93-
dataset = data.AIAssimilationDataset(first_guess, observations)
94-
95-
# Verify dataset properties
96-
self.assertEqual(len(dataset), num_samples, "Dataset length should match number of samples")
97-
98-
# Get a sample
99-
sample = dataset[0]
100-
101-
# Verify sample structure
102-
self.assertIsInstance(sample, dict, "Sample should be a dictionary")
103-
self.assertIn('first_guess', sample, "Sample should contain 'first_guess'")
104-
self.assertIn('observations', sample, "Sample should contain 'observations'")
105-
106-
# Verify sample shapes
107-
self.assertEqual(sample['first_guess'].shape, (state_size,), "First guess in sample should have correct shape")
108-
self.assertEqual(sample['observations'].shape, (state_size,), "Observations in sample should have correct shape")
109-
110-
def test_trainer_functionality(self):
111-
"""Test that the AI assimilation trainer works correctly."""
112-
state_size = 10
113-
114-
# Create model and loss function
115-
net = model.AIAssimilationNet(state_size=state_size)
116-
loss_fn = loss.ThreeDVarLoss()
117-
118-
# Create trainer
119-
trainer = training.AIBasedAssimilationTrainer(
120-
model=net,
121-
loss_fn=loss_fn,
122-
lr=1e-3,
123-
device='cpu'
124-
)
125-
126-
# Create test batch
127-
batch_fg = torch.randn(2, state_size)
128-
batch_obs = torch.randn(2, state_size)
129-
130-
# Run training step
131-
train_loss = trainer.train_step(batch_fg, batch_obs)
132-
133-
# Verify training step result
134-
self.assertIsInstance(train_loss, float, "Training loss should be a float")
135-
self.assertFalse(torch.isnan(torch.tensor(train_loss)).any().item(), "Training loss should not be NaN")
136-
self.assertFalse(torch.isinf(torch.tensor(train_loss)).any().item(), "Training loss should not be Inf")
17+
def test_model_creation_and_forward_pass():
18+
state_size = 20
19+
net = model.AIAssimilationNet(state_size=state_size)
20+
21+
# Create test inputs
22+
first_guess = torch.randn(3, state_size)
23+
observations = torch.randn(3, state_size)
24+
25+
# Forward pass
26+
analysis = net(first_guess, observations)
27+
28+
# Verify output shape and validity
29+
assert analysis.shape == (3, state_size), "Output shape should match input batch and state size"
30+
assert not torch.isnan(analysis).any().item(), "Output should not contain NaN values"
31+
assert not torch.isinf(analysis).any().item(), "Output should not contain Inf values"
32+
33+
34+
def test_3dvar_loss_function():
35+
"""Test that the 3D-Var loss function works correctly."""
36+
loss_fn = loss.ThreeDVarLoss()
37+
38+
# Create test tensors
39+
batch_size = 2
40+
state_size = 15
41+
analysis = torch.randn(batch_size, state_size)
42+
first_guess = torch.randn(batch_size, state_size)
43+
observations = torch.randn(batch_size, state_size)
44+
45+
# Calculate loss
46+
total_loss = loss_fn(analysis, first_guess, observations)
47+
48+
# Verify loss properties
49+
assert total_loss.dim() == 0, "Loss should be a scalar tensor"
50+
assert total_loss >= 0, "Loss should be non-negative"
51+
assert not torch.isnan(total_loss).any().item(), "Loss should not contain NaN values"
52+
assert not torch.isinf(total_loss).any().item(), "Loss should not contain Inf values"
13753

13854

139-
def run_tests():
140-
"""Run all AI assimilation tests."""
141-
print("Running AI Assimilation Tests...\n")
142-
143-
# Create test suite
144-
suite = unittest.TestLoader().loadTestsFromTestCase(TestAIAssimilation)
145-
146-
# Run tests
147-
runner = unittest.TextTestRunner(verbosity=2)
148-
result = runner.run(suite)
149-
150-
# Print summary
151-
print("\n" + "="*50)
152-
if result.wasSuccessful():
153-
print("ALL TESTS PASSED! AI assimilation module is working correctly.")
154-
else:
155-
print(" SOME TESTS FAILED! Please check the AI assimilation module.")
156-
print(f"Failures: {len(result.failures)}, Errors: {len(result.errors)}")
157-
158-
return result.wasSuccessful()
55+
def test_dataset_creation():
56+
# Create test data directly
57+
batch_size = 8
58+
state_size = 12
59+
60+
first_guess = torch.randn(batch_size, state_size)
61+
observations = torch.randn(batch_size, state_size)
62+
63+
# Create dataset
64+
dataset = data.AIAssimilationDataset(first_guess, observations)
65+
66+
# Verify dataset properties
67+
assert len(dataset) == batch_size, "Dataset length should match number of samples"
68+
69+
# Get a sample
70+
sample = dataset[0]
71+
72+
# Verify sample structure
73+
assert isinstance(sample, dict), "Sample should be a dictionary"
74+
assert 'first_guess' in sample, "Sample should contain 'first_guess'"
75+
assert 'observations' in sample, "Sample should contain 'observations'"
76+
77+
# Verify sample shapes
78+
assert sample['first_guess'].shape == (state_size,), "First guess in sample should have correct shape"
79+
assert sample['observations'].shape == (state_size,), "Observations in sample should have correct shape"
15980

16081

161-
if __name__ == "__main__":
162-
success = run_tests()
163-
if success:
164-
print("\nAI assimilation module verification completed successfully!")
165-
print("\nComponents tested:")
166-
print("- AIAssimilationNet (model): Passed")
167-
print("- ThreeDVarLoss (loss function): Passed")
168-
print("- AIAssimilationDataset (data handling): Passed")
169-
print("- AIBasedAssimilationTrainer (training): Passed")
170-
print("- Synthetic data generation: Passed")
171-
else:
172-
print("\n AI assimilation module has issues that need to be addressed.")
82+
def test_trainer_functionality():
83+
state_size = 10
84+
85+
# Create model and loss function
86+
net = model.AIAssimilationNet(state_size=state_size)
87+
loss_fn = loss.ThreeDVarLoss()
88+
89+
# Create trainer
90+
trainer = training.AIBasedAssimilationTrainer(
91+
model=net,
92+
loss_fn=loss_fn,
93+
lr=1e-3,
94+
device='cpu'
95+
)
96+
97+
# Create test batch
98+
batch_fg = torch.randn(2, state_size)
99+
batch_obs = torch.randn(2, state_size)
100+
101+
# Run training step
102+
train_loss = trainer.train_step(batch_fg, batch_obs)
103+
104+
# Verify training step result
105+
assert isinstance(train_loss, float), "Training loss should be a float"
106+
assert not torch.isnan(torch.tensor(train_loss)).any().item(), "Training loss should not be NaN"
107+
assert not torch.isinf(torch.tensor(train_loss)).any().item(), "Training loss should not be Inf"

0 commit comments

Comments
 (0)