Skip to content

Commit 9abaf0c

Browse files
committed
ruffed
1 parent 8d6c07a commit 9abaf0c

File tree

4 files changed

+61
-46
lines changed

4 files changed

+61
-46
lines changed

utils/dataloaders/mnist_4_9.py

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,53 +5,44 @@
55
import numpy as np
66
from torch.utils.data import Dataset
77

8+
89
class MNIST_4_9(Dataset):
9-
def __init__(self,
10-
datapath: Path,
11-
train: bool = False,
12-
download: bool = False
13-
):
10+
def __init__(self, datapath: Path, train: bool = False, download: bool = False):
1411
super.__init__()
1512
self.datapath = datapath
1613
self.mnist_path = self.datapath / "MNIST"
1714
self.train = train
1815
self.download = download
1916
self.num_classes: int = 6
20-
17+
2118
if not self.download and not self._already_downloaded():
2219
raise FileNotFoundError(
23-
'Data files are not found. Set --download-data=True to download the data'
20+
"Data files are not found. Set --download-data=True to download the data"
2421
)
2522
if self.download and not self._already_downloaded():
2623
self._download()
27-
28-
29-
30-
24+
3125
def _download(self):
3226
urls: dict = {
3327
"train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
3428
"train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
3529
"test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
3630
"test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
3731
}
38-
39-
32+
4033
for url in urls.values():
41-
file_path: Path = os.path.join(self.mnist_path, url.split('/')[-1])
42-
file_name: Path = file_path.replace('.gz','')
34+
file_path: Path = os.path.join(self.mnist_path, url.split("/")[-1])
35+
file_name: Path = file_path.replace(".gz", "")
4336
if os.path.exists(file_name):
4437
print(f"File: {file_name} already downloaded")
4538
else:
4639
print(f"File: {file_name} is downloading...")
47-
ur.urlretrieve(url, file_path) # Download file
48-
with gzip.open(file_path, 'rb') as infile:
49-
with open(file_name, 'wb') as outfile:
50-
outfile.write(infile.read()) # Write from url to local file
51-
os.remove(file_path) # remove .gz file
52-
53-
54-
40+
ur.urlretrieve(url, file_path) # Download file
41+
with gzip.open(file_path, "rb") as infile:
42+
with open(file_name, "wb") as outfile:
43+
outfile.write(infile.read()) # Write from url to local file
44+
os.remove(file_path) # remove .gz file
45+
5546
def _already_downloaded(self):
5647
if self.mnist_path.exists():
5748
required_files: list = [
@@ -65,11 +56,9 @@ def _already_downloaded(self):
6556
else:
6657
self.mnist_path.mkdir(parents=True, exist_ok=True)
6758
return False
68-
59+
6960
def __len__(self):
7061
pass
71-
62+
7263
def __getitem__(self):
7364
pass
74-
75-

utils/load_metric.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,21 @@ def _get_metric(self, key):
7575
case "entropy":
7676
return EntropyPrediction(num_classes=self.num_classes)
7777
case "f1":
78-
return F1Score(num_classes=self.num_classes, macro_averaging=self.macro_averaging)
78+
return F1Score(
79+
num_classes=self.num_classes, macro_averaging=self.macro_averaging
80+
)
7981
case "recall":
80-
return Recall(num_classes=self.num_classes, macro_averaging=self.macro_averaging)
82+
return Recall(
83+
num_classes=self.num_classes, macro_averaging=self.macro_averaging
84+
)
8185
case "precision":
82-
return Precision(num_classes=self.num_classes, macro_averaging=self.macro_averaging)
86+
return Precision(
87+
num_classes=self.num_classes, macro_averaging=self.macro_averaging
88+
)
8389
case "accuracy":
84-
return Accuracy(num_classes=self.num_classes, macro_averaging=self.macro_averaging)
90+
return Accuracy(
91+
num_classes=self.num_classes, macro_averaging=self.macro_averaging
92+
)
8593
case _:
8694
raise ValueError(f"Metric {key} not supported")
8795

utils/metrics/F1.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ def _micro_F1(self):
7676
precision = tp / (tp + fp + 1e-8) # Avoid division by zero
7777
recall = tp / (tp + fn + 1e-8) # Avoid division by zero
7878

79-
f1 = 2 * precision * recall / (precision + recall + 1e-8) # Avoid division by zero
79+
f1 = (
80+
2 * precision * recall / (precision + recall + 1e-8)
81+
) # Avoid division by zero
8082
return f1
8183

8284
def _macro_F1(self):
@@ -91,10 +93,18 @@ def _macro_F1(self):
9193
torch.Tensor
9294
The macro-averaged F1 score.
9395
"""
94-
precision_per_class = self.tp / (self.tp + self.fp + 1e-8) # Avoid division by zero
95-
recall_per_class = self.tp / (self.tp + self.fn + 1e-8) # Avoid division by zero
96-
f1_per_class = 2 * precision_per_class * recall_per_class / (
97-
precision_per_class + recall_per_class + 1e-8) # Avoid division by zero
96+
precision_per_class = self.tp / (
97+
self.tp + self.fp + 1e-8
98+
) # Avoid division by zero
99+
recall_per_class = self.tp / (
100+
self.tp + self.fn + 1e-8
101+
) # Avoid division by zero
102+
f1_per_class = (
103+
2
104+
* precision_per_class
105+
* recall_per_class
106+
/ (precision_per_class + recall_per_class + 1e-8)
107+
) # Avoid division by zero
98108

99109
# Take the average of F1 scores across all classes
100110
f1_score = torch.mean(f1_per_class)
@@ -136,4 +146,4 @@ def forward(self, preds, target):
136146
# Calculate Micro F1 score
137147
f1_score = self._micro_F1()
138148

139-
return f1_score
149+
return f1_score

utils/metrics/accuracy.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ def __init__(self, num_classes, macro_averaging=False):
77
super().__init__()
88
self.num_classes = num_classes
99
self.macro_averaging = macro_averaging
10-
10+
1111
def forward(self, y_true, y_pred):
1212
"""
1313
Compute the accuracy of the model.
@@ -30,7 +30,7 @@ def forward(self, y_true, y_pred):
3030
return self._macro_acc(y_true, y_pred)
3131
else:
3232
return self._micro_acc(y_true, y_pred)
33-
33+
3434
def _macro_acc(self, y_true, y_pred):
3535
"""
3636
Compute the macro-average accuracy.
@@ -51,15 +51,15 @@ def _macro_acc(self, y_true, y_pred):
5151

5252
classes = torch.unique(y_true) # Find unique class labels
5353
acc_per_class = []
54-
54+
5555
for c in classes:
56-
mask = (y_true == c) # Mask for class c
56+
mask = y_true == c # Mask for class c
5757
acc = (y_pred[mask] == y_true[mask]).float().mean() # Accuracy for class c
5858
acc_per_class.append(acc)
59-
59+
6060
macro_acc = torch.stack(acc_per_class).mean().item() # Average across classes
6161
return macro_acc
62-
62+
6363
def _micro_acc(self, y_true, y_pred):
6464
"""
6565
Compute the micro-average accuracy.
@@ -82,13 +82,21 @@ def _micro_acc(self, y_true, y_pred):
8282
if __name__ == "__main__":
8383
accuracy = Accuracy(5)
8484
macro_accuracy = Accuracy(5, macro_averaging=True)
85-
85+
8686
y_true = torch.tensor([0, 3, 2, 3, 4])
8787
y_pred = torch.tensor([0, 1, 2, 3, 4])
8888
print(accuracy(y_true, y_pred))
8989
print(macro_accuracy(y_true, y_pred))
90-
90+
9191
y_true = torch.tensor([0, 3, 2, 3, 4])
92-
y_onehot_pred = torch.tensor([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]])
92+
y_onehot_pred = torch.tensor(
93+
[
94+
[1, 0, 0, 0, 0],
95+
[0, 1, 0, 0, 0],
96+
[0, 0, 1, 0, 0],
97+
[0, 0, 0, 1, 0],
98+
[0, 0, 0, 0, 1],
99+
]
100+
)
93101
print(accuracy(y_true, y_onehot_pred))
94102
print(macro_accuracy(y_true, y_onehot_pred))

0 commit comments

Comments
 (0)