Skip to content

Commit 07a4ede

Browse files
authored
Merge pull request #31 from SFI-Visual-Intelligence/Jan, closes #23
Change of load_model parameters addresses #23
2 parents a5a6c01 + 5ae4d22 commit 07a4ede

File tree

12 files changed

+294
-66
lines changed

12 files changed

+294
-66
lines changed

main.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def main():
7373
"--dataset",
7474
type=str,
7575
default="svhn",
76-
choices=["svhn", "usps_0-6", "uspsh5_7_9"],
76+
choices=["svhn", "usps_0-6", "uspsh5_7_9", "mnist_0-3"],
7777
help="Which dataset to train the model on.",
7878
)
7979

@@ -139,29 +139,29 @@ def main():
139139
data_path=args.datafolder,
140140
)
141141

142-
# Find number of channels in the dataset
143-
if len(traindata[0][0].shape) == 2:
144-
channels = 1
145-
else:
146-
channels = traindata[0][0].shape[0]
142+
# Find the shape of the data, if is 2D, add a channel dimension
143+
data_shape = traindata[0][0].shape
144+
if len(data_shape) == 2:
145+
data_shape = (1, *data_shape)
147146

148147
# load model
149148
model = load_model(
150149
args.modelname,
151-
in_channels=channels,
150+
image_shape=data_shape,
152151
num_classes=traindata.num_classes,
153152
)
154153
model.to(device)
155154

156-
trainloader = DataLoader(traindata,
157-
batch_size=args.batchsize,
158-
shuffle=True,
159-
pin_memory=True,
160-
drop_last=True)
161-
valiloader = DataLoader(validata,
162-
batch_size=args.batchsize,
163-
shuffle=False,
164-
pin_memory=True)
155+
trainloader = DataLoader(
156+
traindata,
157+
batch_size=args.batchsize,
158+
shuffle=True,
159+
pin_memory=True,
160+
drop_last=True,
161+
)
162+
valiloader = DataLoader(
163+
validata, batch_size=args.batchsize, shuffle=False, pin_memory=True
164+
)
165165

166166
criterion = nn.CrossEntropyLoss()
167167
optimizer = th.optim.Adam(model.parameters(), lr=args.learning_rate)
@@ -171,12 +171,10 @@ def main():
171171
print("Dry run completed")
172172
exit(0)
173173

174-
wandb.init(project='',
175-
tags=[])
174+
wandb.init(project="", tags=[])
176175
wandb.watch(model)
177176

178177
for epoch in range(args.epoch):
179-
180178
# Training loop start
181179
trainingloss = []
182180
model.train()
@@ -201,12 +199,14 @@ def main():
201199
loss = criterion(y, pred)
202200
evalloss.append(loss.item())
203201

204-
wandb.log({
205-
'Epoch': epoch,
206-
'Train loss': np.mean(trainingloss),
207-
'Evaluation Loss': np.mean(evalloss)
208-
})
202+
wandb.log(
203+
{
204+
"Epoch": epoch,
205+
"Train loss": np.mean(trainingloss),
206+
"Evaluation Loss": np.mean(evalloss),
207+
}
208+
)
209209

210210

211-
if __name__ == '__main__':
211+
if __name__ == "__main__":
212212
main()

tests/test_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from utils.metrics import Recall, F1Score
1+
from utils.metrics import F1Score, Recall
22

33

44
def test_recall():

tests/test_models.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@
44
from utils.models import ChristianModel
55

66

7-
@pytest.mark.parametrize("in_channels, num_classes", [(1, 6), (3, 6)])
8-
def test_christian_model(in_channels, num_classes):
9-
n, c, h, w = 5, in_channels, 16, 16
7+
@pytest.mark.parametrize(
8+
"image_shape, num_classes",
9+
[((1, 16, 16), 6), ((3, 16, 16), 6)],
10+
)
11+
def test_christian_model(image_shape, num_classes):
12+
n, c, h, w = 5, *image_shape
1013

11-
model = ChristianModel(c, num_classes)
14+
model = ChristianModel(image_shape, num_classes)
1215

1316
x = torch.randn(n, c, h, w)
1417
y = model(x)

utils/dataloaders/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset"]
1+
__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset", "MNISTDataset0_3"]
22

3+
from .mnist_0_3 import MNISTDataset0_3
34
from .usps_0_6 import USPSDataset0_6
4-
from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
5+
from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset

utils/dataloaders/mnist_0_3.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
import gzip
2+
import os
3+
import urllib.request
4+
from pathlib import Path
5+
6+
import numpy as np
7+
from torch.utils.data import Dataset
8+
9+
10+
class MNISTDataset0_3(Dataset):
    """
    A custom dataset class for loading MNIST data, restricted to digits 0 through 3.

    Only samples whose label is less than 4 are exposed: ``len(dataset)`` is the
    number of such samples and ``dataset[i]`` is the i-th of them, in file order.

    Parameters
    ----------
    data_path : Path
        The root directory where the MNIST data is stored or will be downloaded.
        A plain string is accepted and converted to a ``Path``.
    train : bool, optional
        If True, loads the training data, otherwise loads the test data. Default is False.
    transform : callable, optional
        A function/transform that takes in an image and returns a transformed version. Default is None.
    download : bool, optional
        If True, downloads the dataset if it is not already present in the specified data_path. Default is False.

    Attributes
    ----------
    data_path : Path
        The root directory where the MNIST data is stored.
    mnist_path : Path
        The directory where the MNIST data files are stored (``data_path / "MNIST"``).
    train : bool
        Indicates whether the training data or test data is being used.
    transform : callable
        A function/transform that takes in an image and returns a transformed version.
    download : bool
        Indicates whether the dataset should be downloaded if not present.
    num_classes : int
        Number of classes exposed by the dataset (always 4).
    images_path : Path
        The path to the image file (training or test) based on the `train` flag.
    labels_path : Path
        The path to the label file (training or test) based on the `train` flag.
    idx : numpy.ndarray
        Positions, within the full MNIST file, of the samples whose label is less than 4.
    labels : numpy.ndarray
        Labels of the selected samples, aligned with `idx`.
    length : int
        The number of samples in the dataset.

    Raises
    ------
    ValueError
        If the data files are missing and `download` is False.

    Methods
    -------
    _parse_labels(train)
        Parses the labels from the label file.
    _check_is_downloaded()
        Checks if the dataset is already downloaded.
    _download_data()
        Downloads and extracts the MNIST dataset.
    __len__()
        Returns the number of samples in the dataset.
    __getitem__(index)
        Returns the image and label at the specified index.
    """

    # Layout constants used when seeking into the raw files: the label file
    # has an 8-byte header, the image file a 16-byte header, and each image
    # is 28x28 single-byte pixels.
    _LABEL_HEADER_BYTES = 8
    _IMAGE_HEADER_BYTES = 16
    _IMAGE_SIDE = 28

    def __init__(
        self,
        data_path: Path,
        train: bool = False,
        transform=None,
        download: bool = False,
    ):
        super().__init__()

        # Accept str as well as Path; the "/" operator below needs a Path.
        self.data_path = Path(data_path)
        self.mnist_path = self.data_path / "MNIST"
        self.train = train
        self.transform = transform
        self.download = download
        self.num_classes = 4

        if not self._check_is_downloaded():
            if not self.download:
                raise ValueError(
                    "Data not found. Set --download-data=True to download the data."
                )
            self._download_data()

        self.images_path = self.mnist_path / (
            "train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte"
        )
        self.labels_path = self.mnist_path / (
            "train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte"
        )

        labels = self._parse_labels(train=self.train)

        # Keep only the samples belonging to classes 0-3. `idx` maps a
        # dataset index to the sample's position in the full MNIST file.
        self.idx = np.where(labels < 4)[0]
        self.labels = labels[self.idx]

        self.length = len(self.idx)

    def _parse_labels(self, train):
        """Return every label in the label file as a uint8 array.

        `train` is unused (the file is chosen via `self.labels_path`); it is
        kept so existing callers keep working.
        """
        with open(self.labels_path, "rb") as f:
            data = np.frombuffer(
                f.read(), dtype=np.uint8, offset=self._LABEL_HEADER_BYTES
            )
        return data

    def _check_is_downloaded(self):
        """Return True if all four extracted MNIST files exist in `mnist_path`.

        Side effect: creates `mnist_path` when the directory is missing, so a
        subsequent download has somewhere to write.
        """
        if not self.mnist_path.exists():
            self.mnist_path.mkdir(parents=True, exist_ok=True)
            return False

        required_files = [
            "train-images-idx3-ubyte",
            "train-labels-idx1-ubyte",
            "t10k-images-idx3-ubyte",
            "t10k-labels-idx1-ubyte",
        ]
        if all((self.mnist_path / file).exists() for file in required_files):
            print("MNIST Dataset already downloaded.")
            return True
        return False

    # Backward-compatible alias for the original (misspelled) method name.
    _chech_is_downloaded = _check_is_downloaded

    def _download_data(self):
        """Download the four MNIST archives and extract them into `mnist_path`.

        Files that are already extracted are skipped; each ``.gz`` archive is
        removed after extraction.
        """
        urls = {
            "train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
            "train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
            "test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
            "test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
        }

        self.mnist_path.mkdir(parents=True, exist_ok=True)

        for url in urls.values():
            file_path = os.path.join(self.mnist_path, url.split("/")[-1])
            extracted_path = file_path.replace(".gz", "")
            if os.path.exists(extracted_path):  # Avoid re-downloading
                continue
            urllib.request.urlretrieve(url, file_path)
            with gzip.open(file_path, "rb") as f_in:
                with open(extracted_path, "wb") as f_out:
                    f_out.write(f_in.read())
            os.remove(file_path)  # Remove compressed file

    def __len__(self):
        """Return the number of samples with a label in 0-3."""
        return self.length

    def __getitem__(self, index):
        """Return ``(image, label)`` for the `index`-th sample with label 0-3.

        The image is a uint8 array of shape (1, 28, 28) unless `transform`
        changes it; the label is a Python int in ``{0, 1, 2, 3}``.

        Raises
        ------
        IndexError
            If `index` is outside ``[-len(self), len(self))``.
        """
        # Translate the dataset index into the sample's position in the
        # full MNIST file; seeking with `index` directly would return
        # samples of every digit, including 4-9.
        sample = int(self.idx[index])
        label = int(self.labels[index])

        image_bytes = self._IMAGE_SIDE * self._IMAGE_SIDE
        with open(self.images_path, "rb") as f:
            f.seek(self._IMAGE_HEADER_BYTES + sample * image_bytes)
            image = np.frombuffer(f.read(image_bytes), dtype=np.uint8).reshape(
                self._IMAGE_SIDE, self._IMAGE_SIDE
            )

        image = np.expand_dims(image, axis=0)  # Add channel dimension

        if self.transform:
            image = self.transform(image)

        return image, label

utils/load_data.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
from torch.utils.data import Dataset
22

3-
from .dataloaders import USPSDataset0_6, USPSH5_Digit_7_9_Dataset
3+
from .dataloaders import (MNISTDataset0_3, USPSDataset0_6,
4+
USPSH5_Digit_7_9_Dataset)
45

56

67
def load_data(dataset: str, *args, **kwargs) -> Dataset:
78
match dataset.lower():
89
case "usps_0-6":
910
return USPSDataset0_6(*args, **kwargs)
11+
case "mnist_0-3":
12+
return MNISTDataset0_3(*args, **kwargs)
1013
case "usps_7-9":
11-
return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
14+
return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
1215
case _:
1316
raise ValueError(f"Dataset: {dataset} not implemented.")

utils/load_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch.nn as nn
22

3-
from .models import ChristianModel, MagnusModel, SolveigModel
3+
from .models import ChristianModel, JanModel, MagnusModel, SolveigModel
44

55

66
def load_model(modelname: str, *args, **kwargs) -> nn.Module:
@@ -9,6 +9,8 @@ def load_model(modelname: str, *args, **kwargs) -> nn.Module:
99
return MagnusModel(*args, **kwargs)
1010
case "christianmodel":
1111
return ChristianModel(*args, **kwargs)
12+
case "janmodel":
13+
return JanModel(*args, **kwargs)
1214
case "solveigmodel":
1315
return SolveigModel(*args, **kwargs)
1416
case _:

utils/metrics/F1.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,3 @@ def compute(self):
8484
)
8585

8686
return f1_score
87-

utils/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
__all__ = ["MagnusModel", "ChristianModel", "SolveigModel"]
1+
__all__ = ["MagnusModel", "ChristianModel", "JanModel", "SolveigModel"]
22

33
from .christian_model import ChristianModel
4+
from .jan_model import JanModel
45
from .magnus_model import MagnusModel
56
from .solveig_model import SolveigModel

utils/models/christian_model.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ class ChristianModel(nn.Module):
2727
2828
Args
2929
----
30-
in_channels : int
31-
Number of input channels.
30+
image_shape : tuple(int, int, int)
31+
Shape of the input image (C, H, W).
3232
num_classes : int
3333
Number of classes in the dataset.
3434
@@ -49,10 +49,12 @@ class ChristianModel(nn.Module):
4949
FC Output Shape: (5, num_classes)
5050
"""
5151

52-
def __init__(self, in_channels, num_classes):
52+
def __init__(self, image_shape, num_classes):
5353
super().__init__()
5454

55-
self.cnn1 = CNNBlock(in_channels, 50)
55+
C, *_ = image_shape
56+
57+
self.cnn1 = CNNBlock(C, 50)
5658
self.cnn2 = CNNBlock(50, 100)
5759

5860
self.fc1 = nn.Linear(100 * 4 * 4, num_classes)

0 commit comments

Comments
 (0)