Skip to content

Commit af1c6cc

Browse files
committed
Used ruff and isort to fix code structure
1 parent 8a29072 commit af1c6cc

File tree

11 files changed

+143
-129
lines changed

11 files changed

+143
-129
lines changed

main.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import numpy as np
22
import torch as th
33
import torch.nn as nn
4+
import wandb
45
from torch.utils.data import DataLoader
56
from torchvision import transforms
67
from tqdm import tqdm
78

8-
import wandb
99
from utils import MetricWrapper, createfolders, get_args, load_data, load_model
1010

1111

@@ -22,13 +22,13 @@ def main():
2222
------
2323
2424
"""
25-
25+
2626
args = get_args()
27-
27+
2828
createfolders(args.datafolder, args.resultfolder, args.modelfolder)
29-
29+
3030
device = args.device
31-
31+
3232
if args.dataset.lower() in ["usps_0-6", "uspsh5_7_9"]:
3333
augmentations = transforms.Compose(
3434
[
@@ -38,7 +38,7 @@ def main():
3838
)
3939
else:
4040
augmentations = transforms.Compose([transforms.ToTensor()])
41-
41+
4242
# Dataset
4343
traindata = load_data(
4444
args.dataset,
@@ -54,22 +54,22 @@ def main():
5454
download=args.download_data,
5555
transform=augmentations,
5656
)
57-
57+
5858
metrics = MetricWrapper(traindata.num_classes, *args.metric)
59-
59+
6060
# Find the shape of the data, if is 2D, add a channel dimension
6161
data_shape = traindata[0][0].shape
6262
if len(data_shape) == 2:
6363
data_shape = (1, *data_shape)
64-
64+
6565
# load model
6666
model = load_model(
6767
args.modelname,
6868
image_shape=data_shape,
6969
num_classes=traindata.num_classes,
7070
)
7171
model.to(device)
72-
72+
7373
trainloader = DataLoader(
7474
traindata,
7575
batch_size=args.batchsize,
@@ -113,11 +113,11 @@ def main():
113113

114114
# wandb.login(key=WANDB_API)
115115
wandb.init(
116-
entity="ColabCode-org",
117-
# entity="FYS-8805 Exam",
118-
project="Test",
119-
tags=[args.modelname, args.dataset]
120-
)
116+
entity="ColabCode-org",
117+
# entity="FYS-8805 Exam",
118+
project="Test",
119+
tags=[args.modelname, args.dataset],
120+
)
121121
wandb.watch(model)
122122
exit()
123123
for epoch in range(args.epoch):

tests/test_dataloaders.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from utils.dataloaders import USPSDataset0_6, SVHNDataset
1+
from utils.dataloaders import SVHNDataset, USPSDataset0_6
2+
23

34
def test_uspsdataset0_6():
45
from pathlib import Path
@@ -32,28 +33,23 @@ def test_uspsdataset0_6():
3233
assert data.shape == (1, 16, 16)
3334
assert all(target == np.array([0, 0, 0, 0, 0, 0, 1]))
3435

35-
36+
3637
def test_svhn_dataset():
3738
import os
3839
from tempfile import TemporaryDirectory
40+
3941
from torchvision import transforms
40-
41-
with TemporaryDirectory() as tempdir:
42-
43-
trans = transforms.Compose([
44-
transforms.Resize((28,28)),
45-
transforms.ToTensor()
46-
])
47-
48-
dataset = SVHNDataset(tempdir,
49-
train=True,
50-
transform=trans,
51-
download=True,
52-
nr_channels=1)
53-
42+
43+
with TemporaryDirectory() as tempdir:
44+
trans = transforms.Compose([transforms.Resize((28, 28)), transforms.ToTensor()])
45+
46+
dataset = SVHNDataset(
47+
tempdir, train=True, transform=trans, download=True, nr_channels=1
48+
)
49+
5450
assert dataset.__len__() != 0
55-
assert os.path.exists(os.path.join(tempdir, 'train_32x32.mat'))
56-
51+
assert os.path.exists(os.path.join(tempdir, "train_32x32.mat"))
52+
5753
img, label = dataset.__getitem__(0)
58-
assert len(img.size()) == 3 and img.size() == (1,28,28) and img.size(0) == 1
54+
assert len(img.size()) == 3 and img.size() == (1, 28, 28) and img.size(0) == 1
5955
assert len(label.size()) == 1

tests/test_metrics.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from utils.metrics import Accuracy, F1Score, Precision, Recall, EntropyPrediction
1+
from utils.metrics import Accuracy, EntropyPrediction, F1Score, Precision, Recall
22

33

44
def test_recall():
@@ -98,20 +98,25 @@ def test_accuracy():
9898
f"Accuracy Score: {accuracy_score.item()}"
9999
)
100100

101+
101102
def test_entropypred():
102-
import torch as th
103-
104-
metric = EntropyPrediction(averages='mean')
103+
import torch as th
104+
105+
metric = EntropyPrediction(averages="mean")
105106

106-
true_lab = th.Tensor([0,1,1,2,4,3]).reshape(6,1).type(th.LongTensor)
107+
true_lab = th.Tensor([0, 1, 1, 2, 4, 3]).reshape(6, 1).type(th.LongTensor)
107108
pred_logits = th.nn.functional.one_hot(true_lab, 5)
108-
109-
#Test for log(0) errors and expected output
109+
110+
# Test for log(0) errors and expected output
110111
assert th.abs((th.sum(metric(true_lab, pred_logits)) - 0.0)) < 1e-5
111-
112-
pred_logits = th.rand(6,5)
113-
metric2 = EntropyPrediction(averages='sum')
114-
115-
#Test for averaging metric consistency
116-
assert th.abs(th.sum(6*metric(true_lab, pred_logits) - metric2(true_lab, pred_logits))) < 1e-5
117-
112+
113+
pred_logits = th.rand(6, 5)
114+
metric2 = EntropyPrediction(averages="sum")
115+
116+
# Test for averaging metric consistency
117+
assert (
118+
th.abs(
119+
th.sum(6 * metric(true_lab, pred_logits) - metric2(true_lab, pred_logits))
120+
)
121+
< 1e-5
122+
)

tests/test_models.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,15 @@ def test_jan_model(image_shape, num_classes):
3434
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
3535

3636

37-
@pytest.mark.parameterize(
38-
"image_shape",
39-
[(3,28,28)]
40-
)
37+
@pytest.mark.parameterize("image_shape", [(3, 28, 28)])
4138
def test_magnus_model(image_shape):
42-
import torch as th
43-
39+
import torch as th
40+
4441
n, c, h, w = 5, *image_shape
45-
model = MagnusModel([h,w], 10, c)
46-
42+
model = MagnusModel([h, w], 10, c)
43+
4744
x = th.rand((n, c, h, w))
4845
with th.no_grad():
4946
y = model(x)
50-
47+
5148
assert y.shape == (n, 10), f"Shape: {y.shape}"
52-

utils/arg_parser.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,20 +69,15 @@ def get_args():
6969
nargs="+",
7070
help="Which metric to use for evaluation",
7171
)
72-
73-
parser.add_argument(
74-
'--imagesize',
75-
type=int,
76-
default=28,
77-
help='Imagesize'
78-
)
79-
72+
73+
parser.add_argument("--imagesize", type=int, default=28, help="Imagesize")
74+
8075
parser.add_argument(
81-
'--nr_channels',
76+
"--nr_channels",
8277
type=int,
8378
default=1,
84-
choices=[1,3],
85-
help='Number of image channels'
79+
choices=[1, 3],
80+
help="Number of image channels",
8681
)
8782

8883
# Training specific values

utils/dataloaders/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1-
__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset", "MNISTDataset0_3", "SVHNDataset"]
1+
__all__ = [
2+
"USPSDataset0_6",
3+
"USPSH5_Digit_7_9_Dataset",
4+
"MNISTDataset0_3",
5+
"SVHNDataset",
6+
]
27

38
from .mnist_0_3 import MNISTDataset0_3
9+
from .svhn import SVHNDataset
410
from .usps_0_6 import USPSDataset0_6
511
from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
6-
from .svhn import SVHNDataset

utils/dataloaders/svhn.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
23
import numpy as np
34
from scipy.io import loadmat
45
from torch.utils.data import Dataset
@@ -7,13 +8,13 @@
78

89
class SVHNDataset(Dataset):
910
def __init__(
10-
self,
11-
data_path: str,
11+
self,
12+
data_path: str,
1213
train: bool,
13-
transform=None,
14-
download:bool=True,
15-
nr_channels=3
16-
):
14+
transform=None,
15+
download: bool = True,
16+
nr_channels=3,
17+
):
1718
"""
1819
Initializes the SVHNDataset object.
1920
Args:
@@ -26,8 +27,8 @@ def __init__(
2627
"""
2728
super().__init__()
2829
# assert split == "train" or split == "test"
29-
self.split = 'train' if train else 'test'
30-
30+
self.split = "train" if train else "test"
31+
3132
if download:
3233
self._download_data(data_path)
3334

@@ -37,7 +38,7 @@ def __init__(
3738
self.images = data["X"].transpose(3, 1, 0, 2)
3839
self.labels = data["y"].flatten()
3940
self.labels[self.labels == 10] = 0
40-
41+
4142
self.nr_channels = nr_channels
4243
self.transforms = transform
4344
self.num_classes = len(np.unique(self.labels))
@@ -50,7 +51,7 @@ def _download_data(self, path: str):
5051
split (str): The dataset split to download, either 'train' or 'test'.
5152
"""
5253
print(f"Downloading SVHN data into {path}")
53-
54+
5455
SVHN(path, split=self.split, download=True)
5556

5657
def __len__(self):
@@ -73,7 +74,7 @@ def __getitem__(self, index):
7374

7475
if self.nr_channels == 1:
7576
img = np.mean(img, axis=2, keepdims=True)
76-
77+
7778
if self.transforms is not None:
7879
img = self.transforms(img)
7980

utils/load_data.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
from torch.utils.data import Dataset
22

3-
from .dataloaders import MNISTDataset0_3, USPSDataset0_6, USPSH5_Digit_7_9_Dataset, SVHNDataset
3+
from .dataloaders import (
4+
MNISTDataset0_3,
5+
SVHNDataset,
6+
USPSDataset0_6,
7+
USPSH5_Digit_7_9_Dataset,
8+
)
49

510

611
def load_data(dataset: str, *args, **kwargs) -> Dataset:

utils/load_metric.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import copy
2+
23
import numpy as np
34
import torch.nn as nn
5+
46
from .metrics import Accuracy, EntropyPrediction, F1Score, Precision, Recall
57

68

@@ -70,7 +72,7 @@ def _get_metric(self, key):
7072

7173
match key.lower():
7274
case "entropy":
73-
#Not dependent on knowing the number of classes
75+
# Not dependent on knowing the number of classes
7476
return EntropyPrediction()
7577
case "f1":
7678
return F1Score(num_classes=self.num_classes)

0 commit comments

Comments
 (0)