Skip to content

Commit 59e1d5a

Browse files
authored
Merge pull request #69 from SFI-Visual-Intelligence/usps-inputshape
Usps inputshape, resolve #67 LGTM
2 parents 725cb0b + 0305d46 commit 59e1d5a

File tree

4 files changed

+61
-6
lines changed

4 files changed

+61
-6
lines changed

main.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,12 @@ def main():
3030

3131
device = args.device
3232

33-
if args.dataset.lower() in ["usps_0-6", "usps_7-9"]:
33+
34+
if "usps" in args.dataset.lower():
35+
3436
transform = transforms.Compose(
3537
[
36-
transforms.Resize((16, 16)),
38+
transforms.Resize((28, 28)),
3739
transforms.ToTensor(),
3840
]
3941
)
@@ -45,6 +47,7 @@ def main():
4547
data_dir=args.datafolder,
4648
transform=transform,
4749
val_size=args.val_size,
50+
4851
)
4952

5053
train_metrics = MetricWrapper(
@@ -126,6 +129,7 @@ def main():
126129
project=args.run_name,
127130
tags=[args.modelname, args.dataset],
128131
config=args,
132+
129133
)
130134
wandb.watch(model)
131135

utils/arg_parser.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def get_args():
3333
help="Whether model should be saved or not.",
3434
)
3535

36+
3637
# Data/Model specific values
3738
parser.add_argument(
3839
"--modelname",
@@ -82,6 +83,7 @@ def get_args():
8283
"--macro_averaging",
8384
action="store_true",
8485
help="If the flag is included, the metrics will be calculated using macro averaging.",
86+
8587
)
8688

8789
# Training specific values

utils/dataloaders/svhn.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22

3+
34
import h5py
45
import numpy as np
56
from PIL import Image
@@ -29,6 +30,7 @@ def __init__(
2930
AssertionError: If the split is not 'train' or 'test'.
3031
"""
3132
super().__init__()
33+
3234
self.data_path = data_path
3335
self.split = "train" if train else "test"
3436

@@ -55,6 +57,7 @@ def _download_data(self, path: str):
5557
path (str): The directory where the dataset will be downloaded.
5658
"""
5759
print(f"Downloading SVHN data into {path}")
60+
5861
SVHN(path, split=self.split, download=True)
5962
data = loadmat(os.path.join(path, f"{self.split}_32x32.mat"))
6063

@@ -92,8 +95,8 @@ def __getitem__(self, index):
9295
img = Image.fromarray(h5f["images"][index])
9396

9497
if self.nr_channels == 1:
95-
img = img.convert("L")
9698

99+
img = img.convert("L")
97100
if self.transforms is not None:
98101
img = self.transforms(img)
99102

utils/models/christian_model.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,18 @@
33

44

55
class CNNBlock(nn.Module):
6+
"""
7+
CNN block with Conv2d, MaxPool2d, and ReLU.
8+
9+
Args
10+
----
11+
12+
in_channels : int
13+
Number of input channels.
14+
out_channels : int
15+
Number of output channels.
16+
"""
17+
618
def __init__(self, in_channels, out_channels):
719
super().__init__()
820

@@ -22,6 +34,37 @@ def forward(self, x):
2234
return x
2335

2436

37+
def find_fc_input_shape(image_shape, *cnn_layers):
38+
"""
39+
Find the shape of the input to the fully connected layer.
40+
41+
Code inspired by @Seilmast (https://github.com/SFI-Visual-Intelligence/Collaborative-Coding-Exam/issues/67#issuecomment-2651212254)
42+
43+
Args
44+
----
45+
image_shape : tuple(int, int, int)
46+
Shape of the input image (C, H, W).
47+
cnn_layers : nn.Module
48+
List of CNN layers.
49+
50+
Returns
51+
-------
52+
int
53+
Number of elements in the input to the fully connected layer.
54+
"""
55+
56+
dummy_img = torch.randn(1, *image_shape)
57+
with torch.no_grad():
58+
x = cnn_layers[0](dummy_img)
59+
60+
for layer in cnn_layers[1:]:
61+
x = layer(x)
62+
63+
x = x.view(x.size(0), -1)
64+
65+
return x.size(1)
66+
67+
2568
class ChristianModel(nn.Module):
2669
"""Simple CNN model for image classification.
2770
@@ -57,7 +100,9 @@ def __init__(self, image_shape, num_classes):
57100
self.cnn1 = CNNBlock(C, 50)
58101
self.cnn2 = CNNBlock(50, 100)
59102

60-
self.fc1 = nn.Linear(100 * 4 * 4, num_classes)
103+
fc_input_shape = find_fc_input_shape(image_shape, self.cnn1, self.cnn2)
104+
105+
self.fc1 = nn.Linear(fc_input_shape, num_classes)
61106

62107
def forward(self, x):
63108
x = self.cnn1(x)
@@ -70,9 +115,10 @@ def forward(self, x):
70115

71116

72117
if __name__ == "__main__":
73-
model = ChristianModel(3, 7)
118+
x = torch.randn(3, 3, 28, 28)
119+
120+
model = ChristianModel(x.shape[1:], 7)
74121

75-
x = torch.randn(3, 3, 16, 16)
76122
y = model(x)
77123

78124
print(y)

0 commit comments

Comments
 (0)