Skip to content

Commit 57e4892

Browse files
committed
Pulled main and fixed pathing
2 parents d90cf35 + 018b669 commit 57e4892

File tree

8 files changed

+121
-15
lines changed

8 files changed

+121
-15
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ _build/
77
bin/*
88
wandb/*
99
wandb_api.py
10+
doc/autoapi
1011

1112
#Magnus specific
1213
job*

CollaborativeCoding/dataloaders/svhn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def __init__(
2929
AssertionError: If the split is not 'train' or 'test'.
3030
"""
3131
super().__init__()
32+
3233
self.data_path = data_path
3334
self.split = "train" if train else "test"
3435

@@ -55,6 +56,7 @@ def _download_data(self, path: str):
5556
path (str): The directory where the dataset will be downloaded.
5657
"""
5758
print(f"Downloading SVHN data into {path}")
59+
5860
SVHN(path, split=self.split, download=True)
5961
data = loadmat(os.path.join(path, f"{self.split}_32x32.mat"))
6062

@@ -93,7 +95,6 @@ def __getitem__(self, index):
9395

9496
if self.nr_channels == 1:
9597
img = img.convert("L")
96-
9798
if self.transforms is not None:
9899
img = self.transforms(img)
99100

CollaborativeCoding/dataloaders/uspsh5_7_9.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from pathlib import Path
2+
13
import h5py
24
import numpy as np
35
import torch
@@ -30,7 +32,7 @@ class USPSH5_Digit_7_9_Dataset(Dataset):
3032
A transform function to apply to the images.
3133
"""
3234

33-
def __init__(self, h5_path, mode, transform=None):
35+
def __init__(self, data_path, train=False, transform=None):
3436
super().__init__()
3537
"""
3638
Initializes the USPS dataset by loading images and labels from the given `.h5` file.
@@ -43,12 +45,15 @@ def __init__(self, h5_path, mode, transform=None):
4345
transform : callable, optional, default=None
4446
A transform function to apply on images.
4547
"""
46-
48+
self.filename = "usps.h5"
49+
path = data_path if isinstance(data_path, Path) else Path(data_path)
50+
self.filepath = path / self.filename
4751
self.transform = transform
48-
self.mode = mode
49-
self.h5_path = h5_path
52+
self.mode = "train" if train else "test"
53+
self.h5_path = data_path / self.filename
54+
5055
# Load the dataset from the HDF5 file
51-
with h5py.File(self.h5_path, "r") as hf:
56+
with h5py.File(self.filepath, "r") as hf:
5257
images = hf[self.mode]["data"][:]
5358
labels = hf[self.mode]["target"][:]
5459

@@ -105,8 +110,8 @@ def main():
105110

106111
# Load the dataset
107112
dataset = USPSH5_Digit_7_9_Dataset(
108-
h5_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git/usps.h5",
109-
mode="train",
113+
data_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git",
114+
train=False,
110115
transform=transform,
111116
)
112117
data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

CollaborativeCoding/metrics/F1.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def _macro_F1(self):
112112

113113
def forward(self, preds, target):
114114
"""
115+
115116
Update the True Positives, False Positives, and False Negatives, and compute the F1 score.
116117
117118
This method computes the F1 score based on the predictions and true labels. It can compute either the

CollaborativeCoding/models/christian_model.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,18 @@
33

44

55
class CNNBlock(nn.Module):
6+
"""
7+
CNN block with Conv2d, MaxPool2d, and ReLU.
8+
9+
Args
10+
----
11+
12+
in_channels : int
13+
Number of input channels.
14+
out_channels : int
15+
Number of output channels.
16+
"""
17+
618
def __init__(self, in_channels, out_channels):
719
super().__init__()
820

@@ -22,6 +34,37 @@ def forward(self, x):
2234
return x
2335

2436

37+
def find_fc_input_shape(image_shape, *cnn_layers):
38+
"""
39+
Find the shape of the input to the fully connected layer.
40+
41+
Code inspired by @Seilmast (https://github.com/SFI-Visual-Intelligence/Collaborative-Coding-Exam/issues/67#issuecomment-2651212254)
42+
43+
Args
44+
----
45+
image_shape : tuple(int, int, int)
46+
Shape of the input image (C, H, W).
47+
cnn_layers : nn.Module
48+
List of CNN layers.
49+
50+
Returns
51+
-------
52+
int
53+
Number of elements in the input to the fully connected layer.
54+
"""
55+
56+
dummy_img = torch.randn(1, *image_shape)
57+
with torch.no_grad():
58+
x = cnn_layers[0](dummy_img)
59+
60+
for layer in cnn_layers[1:]:
61+
x = layer(x)
62+
63+
x = x.view(x.size(0), -1)
64+
65+
return x.size(1)
66+
67+
2568
class ChristianModel(nn.Module):
2669
"""Simple CNN model for image classification.
2770
@@ -57,7 +100,9 @@ def __init__(self, image_shape, num_classes):
57100
self.cnn1 = CNNBlock(C, 50)
58101
self.cnn2 = CNNBlock(50, 100)
59102

60-
self.fc1 = nn.Linear(100 * 4 * 4, num_classes)
103+
fc_input_shape = find_fc_input_shape(image_shape, self.cnn1, self.cnn2)
104+
105+
self.fc1 = nn.Linear(fc_input_shape, num_classes)
61106

62107
def forward(self, x):
63108
x = self.cnn1(x)
@@ -70,9 +115,10 @@ def forward(self, x):
70115

71116

72117
if __name__ == "__main__":
73-
model = ChristianModel(3, 7)
118+
x = torch.randn(3, 3, 28, 28)
119+
120+
model = ChristianModel(x.shape[1:], 7)
74121

75-
x = torch.randn(3, 3, 16, 16)
76122
y = model(x)
77123

78124
print(y)

main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ def main():
3535

3636
device = args.device
3737

38-
if args.dataset.lower() in ["usps_0-6", "usps_7-9"]:
38+
if "usps" in args.dataset.lower():
3939
transform = transforms.Compose(
4040
[
41-
transforms.Resize((16, 16)),
41+
transforms.Resize((28, 28)),
4242
transforms.ToTensor(),
4343
]
4444
)

tests/test_dataloaders.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,45 @@
1-
from CollaborativeCoding.dataloaders.usps_0_6 import USPSDataset0_6
1+
from pathlib import Path
2+
3+
import numpy as np
4+
import pytest
5+
import torch
6+
from PIL import Image
7+
from torchvision import transforms
8+
9+
from CollaborativeCoding.dataloaders import (
10+
MNISTDataset0_3,
11+
USPSDataset0_6,
12+
USPSH5_Digit_7_9_Dataset,
13+
)
14+
from CollaborativeCoding.load_data import load_data
15+
16+
17+
@pytest.mark.parametrize(
18+
"data_name, expected",
19+
[
20+
("usps_0-6", USPSDataset0_6),
21+
("usps_7-9", USPSH5_Digit_7_9_Dataset),
22+
("mnist_0-3", MNISTDataset0_3),
23+
# TODO: Add more datasets here
24+
],
25+
)
26+
def test_load_data(data_name, expected):
27+
dataset = load_data(
28+
data_name,
29+
data_path=Path("data"),
30+
download=True,
31+
transform=transforms.ToTensor(),
32+
)
33+
assert isinstance(dataset, expected)
34+
assert len(dataset) > 0
35+
assert isinstance(dataset[0], tuple)
36+
assert isinstance(dataset[0][0], torch.Tensor)
37+
assert isinstance(
38+
dataset[0][1], (int, torch.Tensor, np.ndarray)
39+
) # Should probably restrict this to only int or one-hot encoded tensor or array for consistency.
240

341

442
def test_uspsdataset0_6():
5-
from pathlib import Path
643
from tempfile import TemporaryDirectory
744

845
import h5py

tests/test_models.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,21 @@ def test_jan_model(image_shape, num_classes):
3434
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
3535

3636

37+
@pytest.mark.parametrize(
38+
"image_shape, num_classes",
39+
[((3, 16, 16), 3), ((3, 16, 16), 7)],
40+
)
41+
def test_solveig_model(image_shape, num_classes):
42+
n, c, h, w = 5, *image_shape
43+
44+
model = SolveigModel(image_shape, num_classes)
45+
46+
x = torch.randn(n, c, h, w)
47+
y = model(x)
48+
49+
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
50+
51+
3752
@pytest.mark.parametrize("image_shape", [(3, 28, 28)])
3853
def test_magnus_model(image_shape):
3954
import torch as th

0 commit comments

Comments
 (0)