
Commit 234b7f6

merged main into solveig-branch
2 parents 4f8725c + 75b1801

File tree

12 files changed: +216 −39 lines changed

.gitignore

Lines changed: 9 additions & 5 deletions
@@ -1,13 +1,17 @@
 __pycache__/
 .ipynb_checkpoints/
-Data/
-Results/
-Experiments/
+Data/*
+Results/*
+Experiments/*
 _build/
-bin/
-wandb/
+bin/*
+wandb/*
 wandb_api.py

+#Magnus specific
+docker/*
+job*
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
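
Note on the switch from ignoring directories to ignoring their contents: a pattern like Data/ makes Git skip the directory entirely, so a later negation rule (for example !Data/.gitkeep) cannot re-include anything inside it, whereas Data/* ignores only the contents and leaves such re-inclusion possible. Whether that was the motivation here is an assumption; the behavioral difference itself is standard Git semantics.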

environment.yml

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ dependencies:
   - ruff
   - scalene
   - tqdm
+  - scipy
   - pip:
       - torch
       - torchvision
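
The scipy addition matches the new imports elsewhere in this merge: scipy.io.loadmat in utils/dataloaders/svhn.py and scipy.stats.entropy in utils/metrics/EntropyPred.py.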

main.py

Lines changed: 15 additions & 14 deletions
@@ -1,13 +1,11 @@
-from pathlib import Path
-
 import numpy as np
 import torch as th
 import torch.nn as nn
-import wandb
 from torch.utils.data import DataLoader
 from torchvision import transforms
 from tqdm import tqdm

+import wandb
 from utils import MetricWrapper, createfolders, get_args, load_data, load_model


@@ -24,6 +22,7 @@ def main():
    ------

    """
+
    args = get_args()

    createfolders(args.datafolder, args.resultfolder, args.modelfolder)
@@ -105,18 +104,22 @@ def main():
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

-            preds = th.argmax(logits, dim=1)
-            metrics(y, preds)
+            metrics(y, logits)

            break
        print(metrics.accumulate())
        print("Dry run completed successfully.")
-        exit(0)
-
-    wandb.login(key=WANDB_API)
-    wandb.init(entity="ColabCode", project="Jan", tags=[args.modelname, args.dataset])
+        exit()
+
+    # wandb.login(key=WANDB_API)
+    wandb.init(
+        entity="ColabCode-org",
+        # entity="FYS-8805 Exam",
+        project="Test",
+        tags=[args.modelname, args.dataset]
+    )
    wandb.watch(model)
-
+    exit()
    for epoch in range(args.epoch):
        # Training loop start
        trainingloss = []
@@ -132,8 +135,7 @@ def main():
            optimizer.zero_grad(set_to_none=True)
            trainingloss.append(loss.item())

-            preds = th.argmax(logits, dim=1)
-            metrics(y, preds)
+            metrics(y, logits)

        wandb.log(metrics.accumulate(str_prefix="Train "))
        metrics.reset()
@@ -148,8 +150,7 @@ def main():
            loss = criterion(logits, y)
            evalloss.append(loss.item())

-            preds = th.argmax(logits, dim=1)
-            metrics(y, preds)
+            metrics(y, logits)

        wandb.log(metrics.accumulate(str_prefix="Evaluation "))
        metrics.reset()
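
Two things stand out in these hunks. First, the new exit() after wandb.watch(model) makes the epoch loop below it unreachable in this revision. Second, all three metric call sites now pass raw logits instead of argmax-ed class predictions, so MetricWrapper is expected to handle logits itself. Its internals are not part of this diff, so the following is only a sketch of the implied contract, with accuracy_from_logits as a hypothetical stand-in:

import torch as th

# Hypothetical sketch: the metric now receives raw logits and performs
# the argmax that main.py previously did at the call site.
def accuracy_from_logits(y_true: th.Tensor, logits: th.Tensor) -> float:
    preds = th.argmax(logits, dim=1)
    return (preds == y_true).float().mean().item()

logits = th.tensor([[2.0, 0.5], [0.1, 1.2]])
y_true = th.tensor([0, 1])
print(accuracy_from_logits(y_true, logits))  # prints 1.0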

utils/arg_parser.py

Lines changed: 32 additions & 4 deletions
@@ -35,7 +35,8 @@ def get_args():

    parser.add_argument(
        "--download-data",
-        action="store_true",
+        type=bool,
+        default=False,
        help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
    )

@@ -44,14 +45,20 @@ def get_args():
        "--modelname",
        type=str,
        default="MagnusModel",
-        choices=["MagnusModel", "ChristianModel", "SolveigModel", "JanModel"],
+        choices=[
+            "MagnusModel",
+            "ChristianModel",
+            "SolveigModel",
+            "JanModel",
+            "JohanModel",
+        ],
        help="Model which to be trained on",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="svhn",
-        choices=["svhn", "usps_0-6", "uspsh5_7_9", "mnist_0-3"],
+        choices=["svhn", "usps_0-6", "usps_7-9", "mnist_0-3", "mnist_4-9"],
        help="Which dataset to train the model on.",
    )

@@ -63,6 +70,21 @@ def get_args():
        nargs="+",
        help="Which metric to use for evaluation",
    )
+
+    parser.add_argument(
+        '--imagesize',
+        type=int,
+        default=28,
+        help='Imagesize'
+    )
+
+    parser.add_argument(
+        '--nr_channels',
+        type=int,
+        default=1,
+        choices=[1,3],
+        help='Number of image channels'
+    )

    # Training specific values
    parser.add_argument(
@@ -95,4 +117,10 @@ def get_args():
        action="store_true",
        help="If true, the code will not run the training loop.",
    )
-    return parser.parse_args()
+    args = parser.parse_args()
+
+    assert args.epoch > 0, "Epoch should be a positive integer."
+    assert args.learning_rate > 0, "Learning rate should be a positive float."
+    assert args.batchsize > 0, "Batch size should be a positive integer."
+
+    return args
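
One caveat with the change from action="store_true" to type=bool: argparse applies bool() to the raw command-line string, and bool() of any non-empty string is True, so --download-data False still enables downloading. A quick demonstration:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--download-data", type=bool, default=False)

# bool() of a non-empty string is always True:
print(parser.parse_args(["--download-data", "False"]).download_data)  # True
print(parser.parse_args(["--download-data", "0"]).download_data)      # True

# Only omitting the flag (or passing an empty string) yields False:
print(parser.parse_args([]).download_data)  # False

The previous action="store_true" form, still used for the dry-run flag above, avoids this pitfall.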

utils/dataloaders/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
-__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset", "MNISTDataset0_3"]
+__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset", "MNISTDataset0_3", "SVHNDataset"]

 from .mnist_0_3 import MNISTDataset0_3
 from .usps_0_6 import USPSDataset0_6
 from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
+from .svhn import SVHNDataset

utils/dataloaders/svhn.py

Lines changed: 71 additions & 4 deletions
@@ -1,12 +1,79 @@
+import os
+import numpy as np
+from scipy.io import loadmat
 from torch.utils.data import Dataset
+from torchvision.datasets import SVHN


-class SVHN(Dataset):
-    def __init__(self):
+class SVHNDataset(Dataset):
+    def __init__(
+        self,
+        data_path: str,
+        train: bool,
+        transform=None,
+        download:bool=True,
+        nr_channels=3
+    ):
+        """
+        Initializes the SVHNDataset object.
+        Args:
+            data_path (str): Path to where the data lies. If download_data is set to True, this is where the data will be downloaded.
+            transforms: Torch composite of transformations which are to be applied to the dataset images.
+            download_data (bool): If True, downloads the dataset to the specified data_path.
+            split (str): The dataset split to use, either 'train' or 'test'.
+        Raises:
+            AssertionError: If the split is not 'train' or 'test'.
+        """
         super().__init__()
+        # assert split == "train" or split == "test"
+        self.split = 'train' if train else 'test'
+
+        if download:
+            self._download_data(data_path)
+
+        data = loadmat(os.path.join(data_path, f"{self.split}_32x32.mat"))
+
+        # Images on the form N x H x W x C
+        self.images = data["X"].transpose(3, 1, 0, 2)
+        self.labels = data["y"].flatten()
+        self.labels[self.labels == 10] = 0
+
+        self.nr_channels = nr_channels
+        self.transforms = transform
+
+    def _download_data(self, path: str):
+        """
+        Downloads the SVHN dataset.
+        Args:
+            path (str): The directory where the dataset will be downloaded.
+            split (str): The dataset split to download, either 'train' or 'test'.
+        """
+        print(f"Downloading SVHN data into {path}")
+
+        SVHN(path, split=self.split, download=True)

     def __len__(self):
-        return
+        """
+        Returns the number of samples in the dataset.
+        Returns:
+            int: The number of samples.
+        """
+        return len(self.labels)

     def __getitem__(self, index):
-        return
+        """
+        Retrieves the image and label at the specified index.
+        Args:
+            index (int): The index of the sample to retrieve.
+        Returns:
+            tuple: A tuple containing the image and its corresponding label.
+        """
+        img, lab = self.images[index], self.labels[index]
+
+        if self.nr_channels == 1:
+            img = np.mean(img, axis=2, keepdims=True)
+
+        if self.transforms is not None:
+            img = self.transforms(img)
+
+        return img, lab
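
A usage sketch for the new dataset class; the Data/svhn path, the transform, and the DataLoader settings are illustrative assumptions, not taken from the diff:

from torch.utils.data import DataLoader
from torchvision import transforms

from utils.dataloaders import SVHNDataset

# Assumed location; torchvision's SVHN downloads train_32x32.mat /
# test_32x32.mat into this directory when download=True.
dataset = SVHNDataset(
    data_path="Data/svhn",
    train=True,
    transform=transforms.ToTensor(),  # H x W x C array -> C x H x W tensor
    download=True,
    nr_channels=1,  # RGB channels are averaged into one grayscale channel
)

img, label = dataset[0]
loader = DataLoader(dataset, batch_size=64, shuffle=True)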

utils/load_data.py

Lines changed: 5 additions & 1 deletion
@@ -1,6 +1,6 @@
 from torch.utils.data import Dataset

-from .dataloaders import MNISTDataset0_3, USPSDataset0_6, USPSH5_Digit_7_9_Dataset
+from .dataloaders import MNISTDataset0_3, USPSDataset0_6, USPSH5_Digit_7_9_Dataset, SVHNDataset


 def load_data(dataset: str, *args, **kwargs) -> Dataset:
@@ -40,5 +40,9 @@ def load_data(dataset: str, *args, **kwargs) -> Dataset:
             return MNISTDataset0_3(*args, **kwargs)
         case "usps_7-9":
             return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
+        case "svhn":
+            return SVHNDataset(*args, **kwargs)
+        case "mnist_4-9":
+            raise NotImplementedError("MNIST 4-9 dataset not yet implemented.")
         case _:
             raise NotImplementedError(f"Dataset: {dataset} not implemented.")
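
The new case can then be reached through the existing load_data entry point; the keyword arguments below are illustrative and are simply forwarded to SVHNDataset:

from utils import load_data

# Forwarded verbatim to SVHNDataset(*args, **kwargs).
train_set = load_data("svhn", data_path="Data/svhn", train=True)

# Unknown names still hit the catch-all case:
try:
    load_data("cifar10")
except NotImplementedError as err:
    print(err)  # Dataset: cifar10 not implemented.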

utils/load_metric.py

Lines changed: 0 additions & 3 deletions
@@ -7,7 +7,6 @@


 class MetricWrapper(nn.Module):
-
     """
     Wrapper class for metrics, that runs multiple metrics on the same data.
@@ -46,9 +45,7 @@ class MetricWrapper(nn.Module):
     {'entropy': [], 'f1': [], 'precision': []}
     """

-
     def __init__(self, *metrics, num_classes):
-
         super().__init__()
         self.metrics = {}
         self.num_classes = num_classes

utils/load_model.py

Lines changed: 3 additions & 1 deletion
@@ -1,6 +1,6 @@
 import torch.nn as nn

-from .models import ChristianModel, JanModel, MagnusModel, SolveigModel
+from .models import ChristianModel, JanModel, JohanModel, MagnusModel, SolveigModel


 def load_model(modelname: str, *args, **kwargs) -> nn.Module:
@@ -44,6 +44,8 @@ def load_model(modelname: str, *args, **kwargs) -> nn.Module:
             return JanModel(*args, **kwargs)
         case "solveigmodel":
             return SolveigModel(*args, **kwargs)
+        case "johanmodel":
+            return JohanModel(*args, **kwargs)
         case _:
             errmsg = (
                 f"Model: {modelname} not implemented. "

utils/metrics/EntropyPred.py

Lines changed: 25 additions & 3 deletions
@@ -1,9 +1,31 @@
 import torch.nn as nn
+from scipy.stats import entropy


 class EntropyPrediction(nn.Module):
-    def __init__(self):
+    def __init__(self, averages: str = "average"):
+        """
+        Initializes the EntropyPrediction module.
+        Args:
+            averages (str): Specifies the method of aggregation for entropy values.
+                Must be either 'average' or 'sum'.
+        Raises:
+            AssertionError: If the averages parameter is not 'average' or 'sum'.
+        """
         super().__init__()

-    def __call__(self, y_true, y_false):
-        return
+        assert averages == "average" or averages == "sum"
+        self.averages = averages
+        self.stored_entropy_values = []
+
+    def __call__(self, y_true, y_false_logits):
+        """
+        Computes the entropy between true labels and predicted logits, storing the results.
+        Args:
+            y_true: The true labels.
+            y_false_logits: The predicted logits.
+        Side Effects:
+            Appends the computed entropy values to the stored_entropy_values list.
+        """
+        entropy_values = entropy(y_true, qk=y_false_logits)
+        return entropy_values
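
Two observations for readers of this metric: with the qk argument, scipy.stats.entropy returns the Kullback-Leibler divergence D(p || q) of the normalized inputs rather than the Shannon entropy of the predictions, and the committed __call__ returns the values without appending to stored_entropy_values, despite what the docstring says. A small check of what the call actually computes:

import numpy as np
from scipy.stats import entropy

p = np.array([0.9, 0.1])
q = np.array([0.6, 0.4])

# With qk given, scipy.stats.entropy computes the KL divergence
# sum(p * log(p / q)), not the entropy of q.
print(entropy(p, qk=q))           # ~0.2263
print(np.sum(p * np.log(p / q)))  # same value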
