Skip to content

Commit e382fb9

Browse files
committed
Format files
1 parent 5b52a95 commit e382fb9

File tree

3 files changed

+59
-50
lines changed

3 files changed

+59
-50
lines changed

utils/createfolders.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import argparse
21
from pathlib import Path
3-
from tempfile import TemporaryDirectory
42

53

64
def createfolders(*dirs: Path) -> None:

utils/dataloaders/uspsh5_7_9.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from torch.utils.data import Dataset
2-
import numpy as np
31
import h5py
4-
from torchvision import transforms
5-
from PIL import Image
2+
import numpy as np
63
import torch
4+
from PIL import Image
5+
from torch.utils.data import Dataset
6+
from torchvision import transforms
77

88

99
class USPSH5_Digit_7_9_Dataset(Dataset):
@@ -95,14 +95,20 @@ def __getitem__(self, id):
9595

9696
def main():
9797
# Example Usage:
98-
transform = transforms.Compose([
99-
transforms.Resize((16, 16)), # Ensure images are 16x16
100-
transforms.ToTensor(),
101-
transforms.Normalize((0.5,), (0.5,)) # Normalize to [-1, 1]
102-
])
98+
transform = transforms.Compose(
99+
[
100+
transforms.Resize((16, 16)), # Ensure images are 16x16
101+
transforms.ToTensor(),
102+
transforms.Normalize((0.5,), (0.5,)), # Normalize to [-1, 1]
103+
]
104+
)
103105

104106
# Load the dataset
105-
dataset = USPSH5_Digit_7_9_Dataset(h5_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git/usps.h5", mode="train", transform=transform)
107+
dataset = USPSH5_Digit_7_9_Dataset(
108+
h5_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git/usps.h5",
109+
mode="train",
110+
transform=transform,
111+
)
106112
data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
107113
batch = next(iter(data_loader)) # grab a batch from the dataloader
108114
img, label = batch
@@ -112,5 +118,6 @@ def main():
112118
# Check dataset size
113119
print(f"Dataset size: {len(dataset)}")
114120

115-
if __name__ == '__main__':
116-
main()
121+
122+
if __name__ == "__main__":
123+
main()

utils/metrics/F1.py

Lines changed: 40 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,41 @@
1-
import torch.nn as nn
21
import torch
2+
import torch.nn as nn
33

44

55
class F1Score(nn.Module):
66
"""
7-
F1 Score implementation with direct averaging inside the compute method.
7+
F1 Score implementation with direct averaging inside the compute method.
8+
9+
Parameters
10+
----------
11+
num_classes : int
12+
Number of classes.
813
9-
Parameters
10-
----------
11-
num_classes : int
12-
Number of classes.
14+
Attributes
15+
----------
16+
num_classes : int
17+
The number of classes.
1318
14-
Attributes
15-
----------
16-
num_classes : int
17-
The number of classes.
19+
tp : torch.Tensor
20+
Tensor for True Positives (TP) for each class.
1821
19-
tp : torch.Tensor
20-
Tensor for True Positives (TP) for each class.
22+
fp : torch.Tensor
23+
Tensor for False Positives (FP) for each class.
2124
22-
fp : torch.Tensor
23-
Tensor for False Positives (FP) for each class.
25+
fn : torch.Tensor
26+
Tensor for False Negatives (FN) for each class.
27+
"""
2428

25-
fn : torch.Tensor
26-
Tensor for False Negatives (FN) for each class.
27-
"""
2829
def __init__(self, num_classes):
2930
"""
30-
Initializes the F1Score object, setting up the necessary state variables.
31+
Initializes the F1Score object, setting up the necessary state variables.
3132
32-
Parameters
33-
----------
34-
num_classes : int
35-
The number of classes in the classification task.
33+
Parameters
34+
----------
35+
num_classes : int
36+
The number of classes in the classification task.
3637
37-
"""
38+
"""
3839

3940
super().__init__()
4041

@@ -47,16 +48,16 @@ def __init__(self, num_classes):
4748

4849
def update(self, preds, target):
4950
"""
50-
Update the variables with predictions and true labels.
51+
Update the variables with predictions and true labels.
5152
52-
Parameters
53-
----------
54-
preds : torch.Tensor
55-
Predicted logits (shape: [batch_size, num_classes]).
53+
Parameters
54+
----------
55+
preds : torch.Tensor
56+
Predicted logits (shape: [batch_size, num_classes]).
5657
57-
target : torch.Tensor
58-
True labels (shape: [batch_size]).
59-
"""
58+
target : torch.Tensor
59+
True labels (shape: [batch_size]).
60+
"""
6061
preds = torch.argmax(preds, dim=1)
6162

6263
# Calculate True Positives (TP), False Positives (FP), and False Negatives (FN) per class
@@ -76,17 +77,20 @@ def compute(self):
7677
"""
7778

7879
# Compute F1 score based on the specified averaging method
79-
f1_score = 2 * torch.sum(self.tp) / (2 * torch.sum(self.tp) + torch.sum(self.fp) + torch.sum(self.fn))
80+
f1_score = (
81+
2
82+
* torch.sum(self.tp)
83+
/ (2 * torch.sum(self.tp) + torch.sum(self.fp) + torch.sum(self.fn))
84+
)
8085

8186
return f1_score
8287

8388

8489
def test_f1score():
8590
f1_metric = F1Score(num_classes=3)
86-
preds = torch.tensor([[0.8, 0.1, 0.1],
87-
[0.2, 0.7, 0.1],
88-
[0.2, 0.3, 0.5],
89-
[0.1, 0.2, 0.7]])
91+
preds = torch.tensor(
92+
[[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.3, 0.5], [0.1, 0.2, 0.7]]
93+
)
9094

9195
target = torch.tensor([0, 1, 0, 2])
9296

0 commit comments

Comments
 (0)