Skip to content

Commit c29b585

Browse files
authored
Merge pull request #81 from SFI-Visual-Intelligence/solveig-develope
added nr_channels to dataloader and updated modelto work with any input size
2 parents 3cb36b6 + 76bb5b5 commit c29b585

File tree

2 files changed

+48
-5
lines changed

2 files changed

+48
-5
lines changed

CollaborativeCoding/dataloaders/uspsh5_7_9.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class USPSH5_Digit_7_9_Dataset(Dataset):
3232
A transform function to apply to the images.
3333
"""
3434

35-
def __init__(self, data_path, train=False, transform=None):
35+
def __init__(self, data_path, sample_ids, train=False, transform=None, nr_channels=1):
3636
super().__init__()
3737
"""
3838
Initializes the USPS dataset by loading images and labels from the given `.h5` file.
@@ -51,6 +51,8 @@ def __init__(self, data_path, train=False, transform=None):
5151
self.transform = transform
5252
self.mode = "train" if train else "test"
5353
self.h5_path = data_path / self.filename
54+
self.sample_ids = sample_ids
55+
self.nr_channels = nr_channels
5456

5557
# Load the dataset from the HDF5 file
5658
with h5py.File(self.filepath, "r") as hf:
@@ -107,10 +109,10 @@ def main():
107109
transforms.Normalize((0.5,), (0.5,)), # Normalize to [-1, 1]
108110
]
109111
)
110-
112+
indices = np.array([7, 8, 9])
111113
# Load the dataset
112114
dataset = USPSH5_Digit_7_9_Dataset(
113-
data_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git",
115+
data_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git", sample_ids=indices,
114116
train=False,
115117
transform=transform,
116118
)

CollaborativeCoding/models/solveig_model.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,37 @@
22
import torch.nn as nn
33

44

5+
def find_fc_input_shape(image_shape, model):
6+
"""
7+
Find the shape of the input to the fully connected layer after passing through the convolutional layers.
8+
9+
Code inspired by @Seilmast (https://github.com/SFI-Visual-Intelligence/Collaborative-Coding-Exam/issues/67#issuecomment-2651212254)
10+
11+
Args
12+
----
13+
image_shape : tuple(int, int, int)
14+
Shape of the input image (C, H, W), where C is the number of channels,
15+
H is the height, and W is the width of the image.
16+
model : nn.Module
17+
The CNN model containing the convolutional layers, whose output size is used to
18+
determine the number of input features for the fully connected layer.
19+
20+
Returns
21+
-------
22+
int
23+
The number of elements in the input to the fully connected layer.
24+
"""
25+
26+
dummy_img = torch.randn(1, *image_shape)
27+
with torch.no_grad():
28+
x = model.conv_block1(dummy_img)
29+
x = model.conv_block2(x)
30+
x = model.conv_block3(x)
31+
x = torch.flatten(x, 1)
32+
33+
return x.size(1)
34+
35+
536
class SolveigModel(nn.Module):
637
"""
738
A Convolutional Neural Network model for classification.
@@ -49,9 +80,19 @@ def __init__(self, image_shape, num_classes):
4980
nn.ReLU(),
5081
)
5182

52-
self.fc1 = nn.Linear(100 * 8 * 8, num_classes)
83+
fc_input_size = find_fc_input_shape(image_shape, self)
84+
85+
self.fc1 = nn.Linear(fc_input_size, num_classes)
5386

5487
def forward(self, x):
88+
"""
89+
Defines the forward pass.
90+
Args:
91+
x (torch.Tensor): A four-dimensional tensor with shape
92+
(Batch Size, Channels, Image Height, Image Width).
93+
Returns:
94+
torch.Tensor: The output tensor containing class logits for each input sample.
95+
"""
5596
x = self.conv_block1(x)
5697
x = self.conv_block2(x)
5798
x = self.conv_block3(x)
@@ -63,7 +104,7 @@ def forward(self, x):
63104

64105

65106
if __name__ == "__main__":
66-
x = torch.randn(1, 3, 16, 16)
107+
x = torch.randn(1, 3, 28, 28)
67108

68109
model = SolveigModel(x.shape[1:], 3)
69110

0 commit comments

Comments
 (0)