Skip to content

Commit ba2212e

Browse files
authored
Merge branch 'main' into Jan-dataloader
2 parents f2e14c4 + 75b1801 commit ba2212e

File tree

11 files changed

+176
-48
lines changed

11 files changed

+176
-48
lines changed

.gitignore

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
__pycache__/
22
.ipynb_checkpoints/
3-
Data/
4-
Results/
5-
Experiments/
3+
Data/*
4+
Results/*
5+
Experiments/*
66
_build/
7-
bin/
8-
wandb/
7+
bin/*
8+
wandb/*
99
wandb_api.py
1010

1111
#Magnus specific
1212
docker/*
13+
job*
1314

1415
# Byte-compiled / optimized / DLL files
1516
__pycache__/

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ dependencies:
1919
- ruff
2020
- scalene
2121
- tqdm
22+
- scipy
2223
- pip:
2324
- torch
2425
- torchvision

main.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import numpy as np
22
import torch as th
33
import torch.nn as nn
4-
import wandb
54
from torch.utils.data import DataLoader
65
from torchvision import transforms
76
from tqdm import tqdm
87

8+
import wandb
99
from utils import MetricWrapper, createfolders, get_args, load_data, load_model
1010

1111

@@ -98,18 +98,22 @@ def main():
9898
optimizer.step()
9999
optimizer.zero_grad(set_to_none=True)
100100

101-
preds = th.argmax(logits, dim=1)
102-
metrics(y, preds)
101+
metrics(y, logits)
103102

104103
break
105104
print(metrics.accumulate())
106105
print("Dry run completed successfully.")
107-
exit(0)
108-
109-
wandb.login(key=WANDB_API)
110-
wandb.init(entity="ColabCode", project="Jan", tags=[args.modelname, args.dataset])
106+
exit()
107+
108+
# wandb.login(key=WANDB_API)
109+
wandb.init(
110+
entity="ColabCode-org",
111+
# entity="FYS-8805 Exam",
112+
project="Test",
113+
tags=[args.modelname, args.dataset]
114+
)
111115
wandb.watch(model)
112-
116+
exit()
113117
for epoch in range(args.epoch):
114118
# Training loop start
115119
trainingloss = []
@@ -125,8 +129,7 @@ def main():
125129
optimizer.zero_grad(set_to_none=True)
126130
trainingloss.append(loss.item())
127131

128-
preds = th.argmax(logits, dim=1)
129-
metrics(y, preds)
132+
metrics(y, logits)
130133

131134
wandb.log(metrics.accumulate(str_prefix="Train "))
132135
metrics.reset()
@@ -141,8 +144,7 @@ def main():
141144
loss = criterion(logits, y)
142145
valloss.append(loss.item())
143146

144-
preds = th.argmax(logits, dim=1)
145-
metrics(y, preds)
147+
metrics(y, logits)
146148

147149
wandb.log(metrics.accumulate(str_prefix="Validation "))
148150
metrics.reset()

utils/arg_parser.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,20 @@ def get_args():
3838
"--modelname",
3939
type=str,
4040
default="MagnusModel",
41-
choices=["MagnusModel", "ChristianModel", "SolveigModel", "JanModel"],
41+
choices=[
42+
"MagnusModel",
43+
"ChristianModel",
44+
"SolveigModel",
45+
"JanModel",
46+
"JohanModel",
47+
],
4248
help="Model which to be trained on",
4349
)
4450
parser.add_argument(
4551
"--dataset",
4652
type=str,
4753
default="svhn",
48-
choices=["svhn", "usps_0-6", "uspsh5_7_9", "mnist_0-3"],
54+
choices=["svhn", "usps_0-6", "usps_7-9", "mnist_0-3", "mnist_4-9"],
4955
help="Which dataset to train the model on.",
5056
)
5157
parser.add_argument(
@@ -62,6 +68,21 @@ def get_args():
6268
nargs="+",
6369
help="Which metric to use for evaluation",
6470
)
71+
72+
parser.add_argument(
73+
'--imagesize',
74+
type=int,
75+
default=28,
76+
help='Imagesize'
77+
)
78+
79+
parser.add_argument(
80+
'--nr_channels',
81+
type=int,
82+
default=1,
83+
choices=[1,3],
84+
help='Number of image channels'
85+
)
6586

6687
# Training specific values
6788
parser.add_argument(
@@ -94,4 +115,10 @@ def get_args():
94115
action="store_true",
95116
help="If true, the code will not run the training loop.",
96117
)
97-
return parser.parse_args()
118+
args = parser.parse_args()
119+
120+
assert args.epoch > 0, "Epoch should be a positive integer."
121+
assert args.learning_rate > 0, "Learning rate should be a positive float."
122+
assert args.batchsize > 0, "Batch size should be a positive integer."
123+
124+
return args

utils/dataloaders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
"USPSH5_Digit_7_9_Dataset",
44
"MNISTDataset0_3",
55
"Downloader",
6+
"SVHNDataset",
67
]
78

89
from .download import Downloader
910
from .mnist_0_3 import MNISTDataset0_3
1011
from .usps_0_6 import USPSDataset0_6
1112
from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
13+
from .svhn import SVHNDataset

utils/dataloaders/svhn.py

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,79 @@
1+
import os
2+
import numpy as np
3+
from scipy.io import loadmat
14
from torch.utils.data import Dataset
5+
from torchvision.datasets import SVHN
26

37

4-
class SVHN(Dataset):
5-
def __init__(self):
8+
class SVHNDataset(Dataset):
    """Street View House Numbers (cropped 32x32 digits) read from the ``.mat`` files.

    Images are kept in memory as one array in ``N x H x W x C`` order and
    labels as an ``int64`` vector with values 0-9.
    """

    def __init__(
        self,
        data_path: str,
        train: bool,
        transform=None,
        download: bool = True,
        nr_channels=3,
    ):
        """
        Initializes the SVHNDataset object.
        Args:
            data_path (str): Directory containing ``train_32x32.mat`` / ``test_32x32.mat``.
                If ``download`` is True, this is also where the data is downloaded to.
            train (bool): If True the 'train' split is loaded, otherwise the 'test' split.
            transform: Optional callable applied to every image in ``__getitem__``.
            download (bool): If True, fetches the split into ``data_path`` before loading.
            nr_channels (int): 3 returns the stored RGB image, 1 returns the
                per-pixel mean over the colour channels.
        Raises:
            ValueError: If ``nr_channels`` is neither 1 nor 3.
        """
        super().__init__()

        # Fail early: any value other than 1 used to fall through silently to RGB.
        if nr_channels not in (1, 3):
            raise ValueError(f"nr_channels must be 1 or 3, got {nr_channels!r}.")

        self.split = "train" if train else "test"

        if download:
            self._download_data(data_path)

        data = loadmat(os.path.join(data_path, f"{self.split}_32x32.mat"))

        # The .mat file stores X as H x W x C x N. Move the sample axis to the
        # front while keeping H before W, giving N x H x W x C. (The earlier
        # permutation (3, 1, 0, 2) swapped H and W, transposing every image.)
        self.images = data["X"].transpose(3, 0, 1, 2)

        # Use int64 so the labels can serve directly as class indices; the
        # .mat file typically stores them as uint8.
        self.labels = data["y"].flatten().astype(np.int64)
        # SVHN encodes the digit 0 as class 10; remap so labels span 0-9.
        self.labels[self.labels == 10] = 0

        self.nr_channels = nr_channels
        self.transforms = transform

    def _download_data(self, path: str):
        """
        Downloads the SVHN split selected in ``__init__``.
        Args:
            path (str): The directory where the dataset will be downloaded.
        """
        print(f"Downloading SVHN data into {path}")

        # Instantiating the torchvision dataset is used only for its download
        # side effect; the returned object is discarded.
        SVHN(path, split=self.split, download=True)

    def __len__(self):
        """
        Returns the number of samples in the dataset.
        Returns:
            int: The number of samples.
        """
        return len(self.labels)

    def __getitem__(self, index):
        """
        Retrieves the image and label at the specified index.
        Args:
            index (int): The index of the sample to retrieve.
        Returns:
            tuple: ``(image, label)``. The image is H x W x C (C == 1 when
            ``nr_channels`` is 1) before ``transform`` is applied; the label
            is an integer in 0-9.
        """
        img, lab = self.images[index], self.labels[index]

        if self.nr_channels == 1:
            # Grayscale as the mean over the colour axis, keeping a trailing
            # channel dimension.
            # NOTE(review): np.mean returns float64 in the original 0-255
            # range, while the RGB path stays in the stored dtype - confirm
            # the transform pipeline treats both the same way.
            img = np.mean(img, axis=2, keepdims=True)

        if self.transforms is not None:
            img = self.transforms(img)

        return img, lab

utils/load_data.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
MNISTDataset0_3,
77
USPSDataset0_6,
88
USPSH5_Digit_7_9_Dataset,
9+
SVHNDataset,
910
)
1011

1112

@@ -59,6 +60,12 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
5960
dataset = MNISTDataset0_3
6061
train_labels, test_labels = downloader.mnist(data_dir=data_dir)
6162
labels = np.arange(4)
63+
case "svhn":
64+
dataset = SVHNDataset
65+
train_labels, test_labels = downloader.svhn(data_dir=data_dir)
66+
labels = np.arange(10)
67+
case "mnist_4-9":
68+
raise NotImplementedError("MNIST 4-9 dataset not yet implemented.")
6269
case _:
6370
raise NotImplementedError(f"Dataset: {dataset} not implemented.")
6471

utils/load_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch.nn as nn
22

3-
from .models import ChristianModel, JanModel, MagnusModel, SolveigModel
3+
from .models import ChristianModel, JanModel, JohanModel, MagnusModel, SolveigModel
44

55

66
def load_model(modelname: str, *args, **kwargs) -> nn.Module:
@@ -44,6 +44,8 @@ def load_model(modelname: str, *args, **kwargs) -> nn.Module:
4444
return JanModel(*args, **kwargs)
4545
case "solveigmodel":
4646
return SolveigModel(*args, **kwargs)
47+
case "johanmodel":
48+
return JohanModel(*args, **kwargs)
4749
case _:
4850
errmsg = (
4951
f"Model: {modelname} not implemented. "

utils/metrics/EntropyPred.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,31 @@
11
import torch.nn as nn
2+
from scipy.stats import entropy
23

34

45
class EntropyPrediction(nn.Module):
    """Tracks the Shannon entropy of a classifier's predicted distributions."""

    def __init__(self, averages: str = "average"):
        """
        Initializes the EntropyPrediction module.
        Args:
            averages (str): Specifies the method of aggregation for entropy values.
                Must be either 'average' or 'sum'.
        Raises:
            AssertionError: If the averages parameter is not 'average' or 'sum'.
        """
        super().__init__()

        # Raised explicitly rather than via `assert` so the check survives
        # `python -O`, while keeping the exception type this class documents.
        if averages not in ("average", "sum"):
            raise AssertionError(
                f"averages must be 'average' or 'sum', got {averages!r}."
            )
        self.averages = averages
        self.stored_entropy_values = []

    def __call__(self, y_true, y_false_logits):
        """
        Computes the entropy of the predicted class distribution of every sample.
        Args:
            y_true: The true labels. Unused by the computation; accepted so the
                metric shares the ``(y_true, logits)`` call signature.
            y_false_logits: The predicted logits, with classes on the last axis.
                May be a torch tensor or any array-like.
        Returns:
            One entropy value (in nats) per sample of the batch.
        Side Effects:
            Appends the computed entropy values to the stored_entropy_values list.
        """
        # Local import: the file's top-level imports only bring in
        # scipy.stats.entropy.
        from scipy.special import softmax

        logits = y_false_logits
        if hasattr(logits, "detach"):
            # torch tensor: drop the autograd graph and move to host memory.
            logits = logits.detach().cpu().numpy()

        # Logits are unnormalised scores; entropy is only defined on
        # probabilities, so normalise over the class axis first. (The previous
        # call, entropy(y_true, qk=logits), computed a KL divergence between
        # the label vector and the raw logits.)
        probs = softmax(logits, axis=-1)
        if probs.ndim == 1:
            # Treat a single logit vector as a batch of one.
            probs = probs[None, :]

        entropy_values = entropy(probs, axis=-1)
        self.stored_entropy_values.extend(entropy_values.reshape(-1).tolist())
        return entropy_values

    def aggregate(self):
        """
        Aggregates all entropy values stored since the last reset.
        Returns:
            float: The sum or the mean of the stored values, depending on
            ``averages``. With nothing stored, the sum is 0.0 and the mean is NaN.
        """
        # NOTE(review): the name the metric wrapper expects for this hook is
        # not visible from this file - confirm and alias if necessary.
        total = float(sum(self.stored_entropy_values))
        if self.averages == "sum":
            return total
        if not self.stored_entropy_values:
            return float("nan")
        return total / len(self.stored_entropy_values)

    def __reset__(self):
        """Discards all stored entropy values (hook kept from the previous revision)."""
        self.stored_entropy_values = []

utils/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
__all__ = ["MagnusModel", "ChristianModel", "JanModel", "SolveigModel"]
1+
__all__ = ["MagnusModel", "ChristianModel", "JanModel", "SolveigModel", "JohanModel"]
22

33
from .christian_model import ChristianModel
44
from .jan_model import JanModel
5+
from .johan_model import JohanModel
56
from .magnus_model import MagnusModel
67
from .solveig_model import SolveigModel

0 commit comments

Comments
 (0)