Skip to content

Commit 5b210f3

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent be2e589 commit 5b210f3

File tree

1 file changed

+26
-27
lines changed

1 file changed

+26
-27
lines changed

tests/test_ai_assimilation.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919
def test_model_creation_and_forward_pass():
2020
state_size = 20
2121
net = model.AIAssimilationNet(state_size=state_size)
22-
22+
2323
# Create test inputs
2424
first_guess = torch.randn(3, state_size)
2525
observations = torch.randn(3, state_size)
26-
26+
2727
# Forward pass
2828
analysis = net(first_guess, observations)
29-
29+
3030
# Verify output shape and validity
3131
assert analysis.shape == (3, state_size), "Output shape should match input batch and state size"
3232
assert not torch.isnan(analysis).any().item(), "Output should not contain NaN values"
@@ -35,17 +35,17 @@ def test_model_creation_and_forward_pass():
3535

3636
def test_3dvar_loss_function():
3737
loss_fn = loss.ThreeDVarLoss()
38-
38+
3939
# Create test tensors
4040
batch_size = 2
4141
state_size = 15
4242
analysis = torch.randn(batch_size, state_size)
4343
first_guess = torch.randn(batch_size, state_size)
4444
observations = torch.randn(batch_size, state_size)
45-
45+
4646
# Calculate loss
4747
total_loss = loss_fn(analysis, first_guess, observations)
48-
48+
4949
# Verify loss properties
5050
assert total_loss.dim() == 0, "Loss should be a scalar tensor"
5151
assert total_loss >= 0, "Loss should be non-negative"
@@ -57,51 +57,50 @@ def test_dataset_creation():
5757
# Create test data directly
5858
batch_size = 8
5959
state_size = 12
60-
60+
6161
first_guess = torch.randn(batch_size, state_size)
6262
observations = torch.randn(batch_size, state_size)
63-
63+
6464
# Create dataset
6565
dataset = data.AIAssimilationDataset(first_guess, observations)
66-
66+
6767
# Verify dataset properties
6868
assert len(dataset) == batch_size, "Dataset length should match number of samples"
69-
69+
7070
# Get a sample
7171
sample = dataset[0]
72-
72+
7373
# Verify sample structure
7474
assert isinstance(sample, dict), "Sample should be a dictionary"
75-
assert 'first_guess' in sample, "Sample should contain 'first_guess'"
76-
assert 'observations' in sample, "Sample should contain 'observations'"
77-
75+
assert "first_guess" in sample, "Sample should contain 'first_guess'"
76+
assert "observations" in sample, "Sample should contain 'observations'"
77+
7878
# Verify sample shapes
79-
assert sample['first_guess'].shape == (state_size,), "First guess in sample should have correct shape"
80-
assert sample['observations'].shape == (state_size,), "Observations in sample should have correct shape"
79+
assert sample["first_guess"].shape == (
80+
state_size,
81+
), "First guess in sample should have correct shape"
82+
assert sample["observations"].shape == (
83+
state_size,
84+
), "Observations in sample should have correct shape"
8185

8286

8387
def test_trainer_functionality():
8488
state_size = 10
85-
89+
8690
# Create model and loss function
8791
net = model.AIAssimilationNet(state_size=state_size)
8892
loss_fn = loss.ThreeDVarLoss()
89-
93+
9094
# Create trainer
91-
trainer = training.AIBasedAssimilationTrainer(
92-
model=net,
93-
loss_fn=loss_fn,
94-
lr=1e-3,
95-
device='cpu'
96-
)
97-
95+
trainer = training.AIBasedAssimilationTrainer(model=net, loss_fn=loss_fn, lr=1e-3, device="cpu")
96+
9897
# Create test batch
9998
batch_fg = torch.randn(2, state_size)
10099
batch_obs = torch.randn(2, state_size)
101-
100+
102101
# Run training step
103102
train_loss = trainer.train_step(batch_fg, batch_obs)
104-
103+
105104
# Verify training step result
106105
assert isinstance(train_loss, float), "Training loss should be a float"
107106
assert not torch.isnan(torch.tensor(train_loss)).any().item(), "Training loss should not be NaN"

0 commit comments

Comments
 (0)