Skip to content

Commit 6f08341

Browse files
committed
Move tests to test directory
1 parent 1254048 commit 6f08341

File tree

8 files changed

+100
-95
lines changed

8 files changed

+100
-95
lines changed

tests/test_createfolders.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from utils import createfolders
2+
3+
4+
def test_createfolders():
5+
import argparse
6+
from pathlib import Path
7+
from tempfile import TemporaryDirectory
8+
9+
with TemporaryDirectory() as temp_dir:
10+
temp_dir = Path(temp_dir)
11+
12+
parser = argparse.ArgumentParser()
13+
14+
# Structuture related values
15+
parser.add_argument(
16+
"--datafolder",
17+
type=Path,
18+
default=temp_dir / "Data",
19+
help="Path to where data will be saved during training.",
20+
)
21+
parser.add_argument(
22+
"--resultfolder",
23+
type=Path,
24+
default=temp_dir / "Results",
25+
help="Path to where results will be saved during evaluation.",
26+
)
27+
parser.add_argument(
28+
"--modelfolder",
29+
type=Path,
30+
default=temp_dir / "Experiments",
31+
help="Path to where model weights will be saved at the end of training.",
32+
)
33+
34+
args = parser.parse_args(
35+
[
36+
"--datafolder",
37+
str(temp_dir / "Data"),
38+
"--resultfolder",
39+
str(temp_dir / "Results"),
40+
"--modelfolder",
41+
str(temp_dir / "Experiments"),
42+
]
43+
)
44+
45+
createfolders(args.datafolder, args.resultfolder, args.modelfolder)
46+
47+
assert (temp_dir / "Data").exists()
48+
assert (temp_dir / "Results").exists()
49+
assert (temp_dir / "Experiments").exists()

tests/test_dataloaders.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from utils.dataloaders.usps_0_6 import USPSDataset0_6
2+
3+
4+
def test_uspsdataset0_6():
5+
from pathlib import Path
6+
7+
import numpy as np
8+
9+
datapath = Path("data/USPS")
10+
11+
dataset = USPSDataset0_6(data_path=datapath, train=True)
12+
assert len(dataset) == 5460
13+
data, target = dataset[0]
14+
assert data.shape == (1, 16, 16)
15+
assert all(target == np.array([0, 0, 0, 0, 0, 0, 1]))

tests/test_metrics.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from utils.metrics import Recall
2+
3+
4+
def test_recall():
5+
import torch
6+
7+
recall = Recall(7)
8+
9+
y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6])
10+
y_pred = torch.tensor([2, 1, 2, 1, 4, 5, 6])
11+
12+
recall_score = recall(y_true, y_pred)
13+
14+
assert recall_score.allclose(torch.tensor(0.7143), atol=1e-5), (
15+
f"Recall Score: {recall_score.item()}"
16+
)

tests/test_models.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import pytest
2+
import torch
3+
4+
from utils.models import ChristianModel
5+
6+
7+
@pytest.mark.parametrize("in_channels, num_classes", [(1, 6), (3, 6)])
8+
def test_christian_model(in_channels, num_classes):
9+
n, c, h, w = 5, in_channels, 16, 16
10+
11+
model = ChristianModel(c, num_classes)
12+
13+
x = torch.randn(n, c, h, w)
14+
y = model(x)
15+
16+
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
17+
assert y.sum(dim=1).allclose(torch.ones(n), atol=1e-5), (
18+
f"Softmax output should sum to 1, but got: {y.sum()}"
19+
)

utils/createfolders.py

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -16,47 +16,3 @@ def createfolders(*dirs: Path) -> None:
1616

1717
for dir in dirs:
1818
dir.mkdir(parents=True, exist_ok=True)
19-
20-
21-
def test_createfolders():
22-
with TemporaryDirectory() as temp_dir:
23-
temp_dir = Path(temp_dir)
24-
25-
parser = argparse.ArgumentParser()
26-
27-
# Structuture related values
28-
parser.add_argument(
29-
"--datafolder",
30-
type=Path,
31-
default=temp_dir / "Data",
32-
help="Path to where data will be saved during training.",
33-
)
34-
parser.add_argument(
35-
"--resultfolder",
36-
type=Path,
37-
default=temp_dir / "Results",
38-
help="Path to where results will be saved during evaluation.",
39-
)
40-
parser.add_argument(
41-
"--modelfolder",
42-
type=Path,
43-
default=temp_dir / "Experiments",
44-
help="Path to where model weights will be saved at the end of training.",
45-
)
46-
47-
args = parser.parse_args(
48-
[
49-
"--datafolder",
50-
temp_dir / "Data",
51-
"--resultfolder",
52-
temp_dir / "Results",
53-
"--modelfolder",
54-
temp_dir / "Experiments",
55-
]
56-
)
57-
58-
createfolders(args.datafolder, args.resultfolder, args.modelfolder)
59-
60-
assert (temp_dir / "Data").exists()
61-
assert (temp_dir / "Results").exists()
62-
assert (temp_dir / "Experiments").exists()

utils/dataloaders/usps_0_6.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -116,19 +116,3 @@ def __getitem__(self, idx):
116116
data = self.transform(data)
117117

118118
return data, target
119-
120-
121-
def test_uspsdataset0_6():
122-
import pytest
123-
124-
datapath = Path("data/USPS/usps.h5")
125-
126-
dataset = USPSDataset0_6(path=datapath, mode="train")
127-
assert len(dataset) == 5460
128-
data, target = dataset[0]
129-
assert data.shape == (16, 16)
130-
assert target == 6
131-
132-
# Test for an invalid mode
133-
with pytest.raises(ValueError):
134-
USPSDataset0_6(path=datapath, mode="inference")

utils/metrics/recall.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -40,23 +40,3 @@ def forward(self, y_true, y_pred):
4040
recall = true_positives / (true_positives + false_negatives)
4141

4242
return recall
43-
44-
45-
def test_recall():
46-
recall = Recall(7)
47-
48-
y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6])
49-
y_pred = torch.tensor([2, 1, 2, 1, 4, 5, 6])
50-
51-
recall_score = recall(y_true, y_pred)
52-
53-
assert recall_score.allclose(torch.tensor(0.7143), atol=1e-5), f"Recall Score: {recall_score.item()}"
54-
55-
56-
def test_one_hot_encode():
57-
num_classes = 7
58-
59-
y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6])
60-
y_onehot = one_hot_encode(y_true, num_classes)
61-
62-
assert y_onehot.shape == (7, 7), f"Shape: {y_onehot.shape}"

utils/models/christian_model.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pytest
21
import torch
32
import torch.nn as nn
43

@@ -49,6 +48,7 @@ class ChristianModel(nn.Module):
4948
CNN2 Output Shape: (5, 100, 4, 4)
5049
FC Output Shape: (5, num_classes)
5150
"""
51+
5252
def __init__(self, in_channels, num_classes):
5353
super().__init__()
5454

@@ -69,21 +69,7 @@ def forward(self, x):
6969
return x
7070

7171

72-
@pytest.mark.parametrize("in_channels, num_classes", [(1, 6), (3, 6)])
73-
def test_christian_model(in_channels, num_classes):
74-
n, c, h, w = 5, in_channels, 16, 16
75-
76-
model = ChristianModel(c, num_classes)
77-
78-
x = torch.randn(n, c, h, w)
79-
y = model(x)
80-
81-
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
82-
assert y.sum(dim=1).allclose(torch.ones(n), atol=1e-5), f"Softmax output should sum to 1, but got: {y.sum()}"
83-
84-
8572
if __name__ == "__main__":
86-
8773
model = ChristianModel(3, 7)
8874

8975
x = torch.randn(3, 3, 16, 16)

0 commit comments

Comments
 (0)