Skip to content

Commit dd99f22

Browse files
committed
Fixed some tampering with my own code.
1 parent a20b338 commit dd99f22

File tree

5 files changed

+22
-20
lines changed

5 files changed

+22
-20
lines changed

main.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from pathlib import Path
2-
31
import numpy as np
42
import torch as th
53
import torch.nn as nn
@@ -111,7 +109,7 @@ def main():
111109
break
112110
print(metrics.accumulate())
113111
print("Dry run completed successfully.")
114-
exit(0)
112+
exit()
115113

116114
# wandb.login(key=WANDB_API)
117115
wandb.init(

utils/dataloaders/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset", "MNISTDataset0_3"]
1+
__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset", "MNISTDataset0_3", "SVHNDataset"]
22

33
from .mnist_0_3 import MNISTDataset0_3
44
from .usps_0_6 import USPSDataset0_6
55
from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
6+
from .svhn import SVHNDataset

utils/dataloaders/svhn.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@
77

88
class SVHNDataset(Dataset):
99
def __init__(
10-
self, datapath: str, transforms=None, download_data=True, split="train"
11-
):
10+
self, datapath: str,
11+
transforms=None,
12+
download_data=True,
13+
split="train"
14+
):
1215
"""
1316
Initializes the SVHNDataset object.
1417
Args:

utils/load_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from torch.utils.data import Dataset
22

3-
from .dataloaders import MNISTDataset0_3, USPSDataset0_6, USPSH5_Digit_7_9_Dataset
3+
from .dataloaders import MNISTDataset0_3, USPSDataset0_6, USPSH5_Digit_7_9_Dataset, SVHNDataset
44

55

66
def load_data(dataset: str, *args, **kwargs) -> Dataset:
@@ -41,7 +41,7 @@ def load_data(dataset: str, *args, **kwargs) -> Dataset:
4141
case "usps_7-9":
4242
return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
4343
case "svhn":
44-
raise NotImplementedError("SVHN dataset not yet implemented.")
44+
raise SVHNDataset(*args, **kwargs)
4545
case "mnist_4-9":
4646
raise NotImplementedError("MNIST 4-9 dataset not yet implemented.")
4747
case _:

utils/models/magnus_model.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,18 @@ def __init__(self, imagesize: int, imagechannels: int, n_classes: int = 10):
2222
self.imagesize = imagesize
2323
self.imagechannels = imagechannels
2424

25-
self.layer1 = nn.Sequential(
26-
*(
27-
[
28-
nn.Linear(
29-
self.imagechannels * self.imagesize * self.imagesize, 133
30-
),
31-
nn.ReLU(),
32-
]
33-
)
34-
)
35-
self.layer2 = nn.Sequential(*([nn.Linear(133, 133), nn.ReLU()]))
36-
self.layer3 = nn.Sequential(*([nn.Linear(133, n_classes), nn.ReLU()]))
25+
self.layer1 = nn.Sequential(*([
26+
nn.Linear(self.imagechannels * self.imagesize * self.imagesize, 133),
27+
nn.ReLU(),
28+
]))
29+
self.layer2 = nn.Sequential(*([
30+
nn.Linear(133, 133),
31+
nn.ReLU()
32+
]))
33+
self.layer3 = nn.Sequential(*([
34+
nn.Linear(133, n_classes),
35+
nn.ReLU()
36+
]))
3737

3838
def forward(self, x):
3939
"""

0 commit comments

Comments
 (0)