Skip to content

Commit ecb6db4

Browse files
committed
Changed the input of load_model to enable models to process all datasets
1 parent 5af2c61 commit ecb6db4

File tree

6 files changed

+111
-11
lines changed

6 files changed

+111
-11
lines changed

main.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,16 +139,15 @@ def main():
139139
data_path=args.datafolder,
140140
)
141141

142-
# Find number of channels in the dataset
143-
if len(traindata[0][0].shape) == 2:
144-
channels = 1
145-
else:
146-
channels = traindata[0][0].shape[0]
142+
# Find the shape of the data, if is 2D, add a channel dimension
143+
data_shape = traindata[0][0].shape
144+
if len(data_shape) == 2:
145+
data_shape = (1, *data_shape)
147146

148147
# load model
149148
model = load_model(
150149
args.modelname,
151-
in_channels=channels,
150+
image_shape=data_shape,
152151
num_classes=traindata.num_classes,
153152
)
154153
model.to(device)

utils/dataloaders/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__all__ = ["USPSDataset0_6"]
1+
__all__ = ["USPSDataset0_6","MNISTDataset0_3"]
22

33
from .usps_0_6 import USPSDataset0_6
44
from .mnist_0_3 import MNISTDataset0_3

utils/dataloaders/mnist_0_3.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def _chech_is_downloaded(self):
8888
if self.mnist_path.exists():
8989
required_files = ["train-images-idx3-ubyte", "train-labels-idx1-ubyte", "t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte"]
9090
if all([(self.mnist_path / file).exists() for file in required_files]):
91-
print("Data already downloaded.")
91+
print("MNIST Dataset already downloaded.")
9292
return True
9393
else:
9494
return False
@@ -126,7 +126,9 @@ def __getitem__(self, index):
126126
with open(self.images_path, "rb") as f:
127127
f.seek(16 + index * 28*28) # Jump to image position
128128
image = np.frombuffer(f.read(28*28), dtype=np.uint8).reshape(28, 28) # Read image data
129-
129+
130+
image = np.expand_dims(image, axis=0) # Add channel dimension
131+
130132
if self.transform:
131133
image = self.transform(image)
132134

utils/load_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch.nn as nn
22

3-
from .models import ChristianModel, MagnusModel
3+
from .models import ChristianModel, MagnusModel, JanModel
44

55

66
def load_model(modelname: str, *args, **kwargs) -> nn.Module:
@@ -9,6 +9,8 @@ def load_model(modelname: str, *args, **kwargs) -> nn.Module:
99
return MagnusModel(*args, **kwargs)
1010
case "christianmodel":
1111
return ChristianModel(*args, **kwargs)
12+
case "janmodel":
13+
return JanModel(*args, **kwargs)
1214
case _:
1315
raise ValueError(
1416
f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling"

utils/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
__all__ = ["MagnusModel", "ChristianModel"]
1+
__all__ = ["MagnusModel", "ChristianModel", "JanModel"]
22

33
from .christian_model import ChristianModel
44
from .magnus_model import MagnusModel
5+
from .jan_model import JanModel

utils/models/jan_model.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import torch
2+
"""
3+
A simple neural network model for classification tasks.
4+
Parameters
5+
----------
6+
in_channels : int
7+
Number of input channels.
8+
num_classes : int
9+
Number of output classes.
10+
Attributes
11+
----------
12+
in_channels : int
13+
Number of input channels.
14+
num_classes : int
15+
Number of output classes.
16+
fc1 : nn.Linear
17+
First fully connected layer.
18+
fc2 : nn.Linear
19+
Second fully connected layer.
20+
out : nn.Linear
21+
Output fully connected layer.
22+
leaky_relu : nn.LeakyReLU
23+
Leaky ReLU activation function.
24+
flatten : nn.Flatten
25+
Flatten layer to reshape input tensor.
26+
Methods
27+
-------
28+
forward(x)
29+
Defines the forward pass of the model.
30+
"""
31+
import torch.nn as nn
32+
33+
34+
35+
class JanModel(nn.Module):
36+
"""A simple MLP network model for image classification tasks.
37+
38+
Args
39+
----
40+
in_channels : int
41+
Number of input channels.
42+
num_classes : int
43+
Number of classes in the dataset.
44+
45+
Processing Images
46+
-----------------
47+
Input: (N, C, H, W)
48+
N: Batch size
49+
C: Number of input channels
50+
H: Height of the input image
51+
W: Width of the input image
52+
53+
Example:
54+
For grayscale images, C = 1.
55+
56+
Input Image Shape: (5, 1, 28, 28)
57+
flatten Output Shape: (5, 784)
58+
fc1 Output Shape: (5, 100)
59+
fc2 Output Shape: (5, 100)
60+
out Output Shape: (5, num_classes)
61+
"""
62+
def __init__(self, image_shape, num_classes):
63+
super().__init__()
64+
65+
self.in_channels = image_shape[0]
66+
self.height = image_shape[1]
67+
self.width = image_shape[2]
68+
self.num_classes = num_classes
69+
70+
self.fc1 = nn.Linear(self.height * self.width * self.in_channels, 100)
71+
72+
self.fc2 = nn.Linear(100, 100)
73+
74+
self.out = nn.Linear(100, num_classes)
75+
76+
self.leaky_relu = nn.LeakyReLU()
77+
78+
self.flatten = nn.Flatten()
79+
80+
def forward(self, x):
81+
x = self.flatten(x)
82+
x = self.fc1(x)
83+
x = self.leaky_relu(x)
84+
x = self.fc2(x)
85+
x = self.leaky_relu(x)
86+
x = self.out(x)
87+
return x
88+
89+
90+
if __name__ == "__main__":
91+
model = JanModel(2, 4)
92+
93+
x = torch.randn(3, 2, 28, 28)
94+
y = model(x)
95+
96+
print(y)

0 commit comments

Comments
 (0)