Skip to content

Commit 3254c29

Browse files
committed
Make ChristianModel input-shape agnostic; the default image shape for USPS datasets is now 28x28
1 parent 88fa115 commit 3254c29

File tree

2 files changed

+61
-15
lines changed

2 files changed

+61
-15
lines changed

main.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import numpy as np
22
import torch as th
33
import torch.nn as nn
4+
import wandb
45
from torch.utils.data import DataLoader
56
from torchvision import transforms
67
from tqdm import tqdm
78

8-
import wandb
99
from utils import MetricWrapper, createfolders, get_args, load_data, load_model
1010

1111

@@ -29,30 +29,30 @@ def main():
2929

3030
device = args.device
3131

32-
if args.dataset.lower() in ["usps_0-6", "uspsh5_7_9"]:
33-
augmentations = transforms.Compose(
32+
if "usps" in args.dataset.lower():
33+
transform = transforms.Compose(
3434
[
35-
transforms.Resize((16, 16)),
35+
transforms.Resize((28, 28)),
3636
transforms.ToTensor(),
3737
]
3838
)
3939
else:
40-
augmentations = transforms.Compose([transforms.ToTensor()])
40+
transform = transforms.Compose([transforms.ToTensor()])
4141

4242
# Dataset
4343
traindata = load_data(
4444
args.dataset,
4545
train=True,
4646
data_path=args.datafolder,
4747
download=args.download_data,
48-
transform=augmentations,
48+
transform=transform,
4949
)
5050
validata = load_data(
5151
args.dataset,
5252
train=False,
5353
data_path=args.datafolder,
5454
download=args.download_data,
55-
transform=augmentations,
55+
transform=transform,
5656
)
5757

5858
metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes)
@@ -113,11 +113,11 @@ def main():
113113

114114
# wandb.login(key=WANDB_API)
115115
wandb.init(
116-
entity="ColabCode-org",
117-
# entity="FYS-8805 Exam",
118-
project="Test",
119-
tags=[args.modelname, args.dataset]
120-
)
116+
entity="ColabCode-org",
117+
# entity="FYS-8805 Exam",
118+
project="Test",
119+
tags=[args.modelname, args.dataset],
120+
)
121121
wandb.watch(model)
122122
exit()
123123
for epoch in range(args.epoch):

utils/models/christian_model.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,18 @@
33

44

55
class CNNBlock(nn.Module):
6+
"""
7+
CNN block with Conv2d, MaxPool2d, and ReLU.
8+
9+
Args
10+
----
11+
12+
in_channels : int
13+
Number of input channels.
14+
out_channels : int
15+
Number of output channels.
16+
"""
17+
618
def __init__(self, in_channels, out_channels):
719
super().__init__()
820

@@ -22,6 +34,37 @@ def forward(self, x):
2234
return x
2335

2436

37+
def find_fc_input_shape(image_shape, *cnn_layers):
    """
    Find the number of input features for the fully connected layer.

    A dummy image of shape ``image_shape`` (batch size 1) is passed through
    ``cnn_layers`` in the order given, and the number of elements per sample
    in the result is returned.

    Code inspired by @Seilmast (https://github.com/SFI-Visual-Intelligence/Collaborative-Coding-Exam/issues/67#issuecomment-2651212254)

    Args
    ----
    image_shape : tuple(int, int, int)
        Shape of the input image (C, H, W).
    *cnn_layers : nn.Module
        CNN layers, applied in the given order. May be empty, in which case
        the flattened size of the image itself is returned.

    Returns
    -------
    int
        Number of elements in the input to the fully connected layer.
    """
    x = torch.randn(1, *image_shape)

    # Shape probing only: keep the whole forward pass out of autograd.
    with torch.no_grad():
        for layer in cnn_layers:
            x = layer(x)

    # Count elements of the single sample directly instead of reshaping with
    # ``view``, which raises on non-contiguous layer outputs.
    return x[0].numel()
66+
67+
2568
class ChristianModel(nn.Module):
2669
"""Simple CNN model for image classification.
2770
@@ -57,7 +100,9 @@ def __init__(self, image_shape, num_classes):
57100
self.cnn1 = CNNBlock(C, 50)
58101
self.cnn2 = CNNBlock(50, 100)
59102

60-
self.fc1 = nn.Linear(100 * 4 * 4, num_classes)
103+
fc_input_shape = find_fc_input_shape(image_shape, self.cnn1, self.cnn2)
104+
105+
self.fc1 = nn.Linear(fc_input_shape, num_classes)
61106

62107
def forward(self, x):
63108
x = self.cnn1(x)
@@ -70,9 +115,10 @@ def forward(self, x):
70115

71116

72117
if __name__ == "__main__":
73-
model = ChristianModel(3, 7)
118+
x = torch.randn(3, 3, 28, 28)
119+
120+
model = ChristianModel(x.shape[1:], 7)
74121

75-
x = torch.randn(3, 3, 16, 16)
76122
y = model(x)
77123

78124
print(y)

0 commit comments

Comments
 (0)