Skip to content

Commit 2ac02eb

Browse files
authored
Merge pull request #59 from SFI-Visual-Intelligence/mag-branch
Fixed some tampering with my part of the code
2 parents cf8da2a + 2376196 commit 2ac02eb

File tree

6 files changed

+35
-30
lines changed

6 files changed

+35
-30
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ wandb_api.py
1010

1111
#Magnus specific
1212
docker/*
13+
job*
1314

1415
# Byte-compiled / optimized / DLL files
1516
__pycache__/

main.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from pathlib import Path
2-
31
import numpy as np
42
import torch as th
53
import torch.nn as nn
@@ -106,18 +104,22 @@ def main():
106104
optimizer.step()
107105
optimizer.zero_grad(set_to_none=True)
108106

109-
preds = th.argmax(logits, dim=1)
110-
metrics(y, preds)
107+
metrics(y, logits)
111108

112109
break
113110
print(metrics.accumulate())
114111
print("Dry run completed successfully.")
115-
exit(0)
116-
117-
wandb.login(key=WANDB_API)
118-
wandb.init(entity="ColabCode", project="Jan", tags=[args.modelname, args.dataset])
112+
exit()
113+
114+
# wandb.login(key=WANDB_API)
115+
wandb.init(
116+
entity="ColabCode-org",
117+
# entity="FYS-8805 Exam",
118+
project="Test",
119+
tags=[args.modelname, args.dataset]
120+
)
119121
wandb.watch(model)
120-
122+
exit()
121123
for epoch in range(args.epoch):
122124
# Training loop start
123125
trainingloss = []
@@ -133,8 +135,7 @@ def main():
133135
optimizer.zero_grad(set_to_none=True)
134136
trainingloss.append(loss.item())
135137

136-
preds = th.argmax(logits, dim=1)
137-
metrics(y, preds)
138+
metrics(y, logits)
138139

139140
wandb.log(metrics.accumulate(str_prefix="Train "))
140141
metrics.reset()
@@ -149,8 +150,7 @@ def main():
149150
loss = criterion(logits, y)
150151
evalloss.append(loss.item())
151152

152-
preds = th.argmax(logits, dim=1)
153-
metrics(y, preds)
153+
metrics(y, logits)
154154

155155
wandb.log(metrics.accumulate(str_prefix="Evaluation "))
156156
metrics.reset()

utils/dataloaders/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset", "MNISTDataset0_3"]
1+
__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset", "MNISTDataset0_3", "SVHNDataset"]
22

33
from .mnist_0_3 import MNISTDataset0_3
44
from .usps_0_6 import USPSDataset0_6
55
from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
6+
from .svhn import SVHNDataset

utils/dataloaders/svhn.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@
77

88
class SVHNDataset(Dataset):
99
def __init__(
10-
self, datapath: str, transforms=None, download_data=True, split="train"
11-
):
10+
self, datapath: str,
11+
transforms=None,
12+
download_data=True,
13+
split="train"
14+
):
1215
"""
1316
Initializes the SVHNDataset object.
1417
Args:

utils/load_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from torch.utils.data import Dataset
22

3-
from .dataloaders import MNISTDataset0_3, USPSDataset0_6, USPSH5_Digit_7_9_Dataset
3+
from .dataloaders import MNISTDataset0_3, USPSDataset0_6, USPSH5_Digit_7_9_Dataset, SVHNDataset
44

55

66
def load_data(dataset: str, *args, **kwargs) -> Dataset:
@@ -41,7 +41,7 @@ def load_data(dataset: str, *args, **kwargs) -> Dataset:
4141
case "usps_7-9":
4242
return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
4343
case "svhn":
44-
raise NotImplementedError("SVHN dataset not yet implemented.")
44+
raise SVHNDataset(*args, **kwargs)
4545
case "mnist_4-9":
4646
raise NotImplementedError("MNIST 4-9 dataset not yet implemented.")
4747
case _:

utils/models/magnus_model.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,18 @@ def __init__(self, imagesize: int, imagechannels: int, n_classes: int = 10):
2222
self.imagesize = imagesize
2323
self.imagechannels = imagechannels
2424

25-
self.layer1 = nn.Sequential(
26-
*(
27-
[
28-
nn.Linear(
29-
self.imagechannels * self.imagesize * self.imagesize, 133
30-
),
31-
nn.ReLU(),
32-
]
33-
)
34-
)
35-
self.layer2 = nn.Sequential(*([nn.Linear(133, 133), nn.ReLU()]))
36-
self.layer3 = nn.Sequential(*([nn.Linear(133, n_classes), nn.ReLU()]))
25+
self.layer1 = nn.Sequential(*([
26+
nn.Linear(self.imagechannels * self.imagesize * self.imagesize, 133),
27+
nn.ReLU(),
28+
]))
29+
self.layer2 = nn.Sequential(*([
30+
nn.Linear(133, 133),
31+
nn.ReLU()
32+
]))
33+
self.layer3 = nn.Sequential(*([
34+
nn.Linear(133, n_classes),
35+
nn.ReLU()
36+
]))
3737

3838
def forward(self, x):
3939
"""

0 commit comments

Comments
 (0)