Skip to content

Commit 6e0c345

Browse files
committed
Ruff and Isort
1 parent a281d98 commit 6e0c345

File tree

6 files changed

+19
-18
lines changed

6 files changed

+19
-18
lines changed

main.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,7 @@ def main():
3030

3131
device = args.device
3232

33-
3433
if "usps" in args.dataset.lower():
35-
3634
transform = transforms.Compose(
3735
[
3836
transforms.Resize((28, 28)),
@@ -47,7 +45,6 @@ def main():
4745
data_dir=args.datafolder,
4846
transform=transform,
4947
val_size=args.val_size,
50-
5148
)
5249

5350
train_metrics = MetricWrapper(
@@ -129,7 +126,6 @@ def main():
129126
project=args.run_name,
130127
tags=[args.modelname, args.dataset],
131128
config=args,
132-
133129
)
134130
wandb.watch(model)
135131

tests/test_models.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import pytest
22
import torch
33

4-
54
from utils.models import ChristianModel, JanModel, MagnusModel, SolveigModel
65

76

utils/arg_parser.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ def get_args():
3333
help="Whether model should be saved or not.",
3434
)
3535

36-
3736
# Data/Model specific values
3837
parser.add_argument(
3938
"--modelname",
@@ -83,7 +82,6 @@ def get_args():
8382
"--macro_averaging",
8483
action="store_true",
8584
help="If the flag is included, the metrics will be calculated using macro averaging.",
86-
8785
)
8886

8987
# Training specific values

utils/dataloaders/svhn.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import os
22

3-
43
import h5py
54
import numpy as np
65
from PIL import Image
@@ -95,7 +94,6 @@ def __getitem__(self, index):
9594
img = Image.fromarray(h5f["images"][index])
9695

9796
if self.nr_channels == 1:
98-
9997
img = img.convert("L")
10098
if self.transforms is not None:
10199
img = self.transforms(img)

utils/dataloaders/uspsh5_7_9.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
from pathlib import Path
2+
13
import h5py
24
import numpy as np
35
import torch
46
from PIL import Image
57
from torch.utils.data import Dataset
68
from torchvision import transforms
7-
from pathlib import Path
89

910

1011
class USPSH5_Digit_7_9_Dataset(Dataset):
@@ -31,7 +32,7 @@ class USPSH5_Digit_7_9_Dataset(Dataset):
3132
A transform function to apply to the images.
3233
"""
3334

34-
def __init__(self, data_path, train = False, transform=None):
35+
def __init__(self, data_path, train=False, transform=None):
3536
super().__init__()
3637
"""
3738
Initializes the USPS dataset by loading images and labels from the given `.h5` file.
@@ -108,7 +109,7 @@ def main():
108109
# Load the dataset
109110
dataset = USPSH5_Digit_7_9_Dataset(
110111
data_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git",
111-
train = False,
112+
train=False,
112113
transform=transform,
113114
)
114115
data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

utils/metrics/F1.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ def _micro_F1(self):
7676
precision = tp / (tp + fp + 1e-8) # Avoid division by zero
7777
recall = tp / (tp + fn + 1e-8) # Avoid division by zero
7878

79-
f1 = 2 * precision * recall / (precision + recall + 1e-8) # Avoid division by zero
79+
f1 = (
80+
2 * precision * recall / (precision + recall + 1e-8)
81+
) # Avoid division by zero
8082
return f1
8183

8284
def _macro_F1(self):
@@ -91,10 +93,18 @@ def _macro_F1(self):
9193
torch.Tensor
9294
The macro-averaged F1 score.
9395
"""
94-
precision_per_class = self.tp / (self.tp + self.fp + 1e-8) # Avoid division by zero
95-
recall_per_class = self.tp / (self.tp + self.fn + 1e-8) # Avoid division by zero
96-
f1_per_class = 2 * precision_per_class * recall_per_class / (
97-
precision_per_class + recall_per_class + 1e-8) # Avoid division by zero
96+
precision_per_class = self.tp / (
97+
self.tp + self.fp + 1e-8
98+
) # Avoid division by zero
99+
recall_per_class = self.tp / (
100+
self.tp + self.fn + 1e-8
101+
) # Avoid division by zero
102+
f1_per_class = (
103+
2
104+
* precision_per_class
105+
* recall_per_class
106+
/ (precision_per_class + recall_per_class + 1e-8)
107+
) # Avoid division by zero
98108

99109
# Take the average of F1 scores across all classes
100110
f1_score = torch.mean(f1_per_class)
@@ -138,4 +148,3 @@ def forward(self, preds, target):
138148
f1_score = self._micro_F1()
139149

140150
return f1_score
141-

0 commit comments

Comments
 (0)