Skip to content

Commit ffe105f

Browse files
committed
Ruff + isort
1 parent c2c72ee commit ffe105f

File tree

6 files changed

+36
-30
lines changed

6 files changed

+36
-30
lines changed

CollaborativeCoding/dataloaders/uspsh5_7_9.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ class USPSH5_Digit_7_9_Dataset(Dataset):
3232
A transform function to apply to the images.
3333
"""
3434

35-
def __init__(self, data_path, sample_ids, train=False, transform=None, nr_channels=1):
35+
def __init__(
36+
self, data_path, sample_ids, train=False, transform=None, nr_channels=1
37+
):
3638
super().__init__()
3739
"""
3840
Initializes the USPS dataset by loading images and labels from the given `.h5` file.
@@ -112,7 +114,8 @@ def main():
112114
indices = np.array([7, 8, 9])
113115
# Load the dataset
114116
dataset = USPSH5_Digit_7_9_Dataset(
115-
data_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git", sample_ids=indices,
117+
data_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git",
118+
sample_ids=indices,
116119
train=False,
117120
transform=transform,
118121
)

CollaborativeCoding/metrics/F1.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,13 @@ def __returnmetric__(self):
159159
else:
160160
self.y_true = torch.cat(self.y_true)
161161
self.y_pred = torch.cat(self.y_pred)
162-
return self._micro_F1(self.y_true, self.y_pred) if not self.macro_averaging else self._macro_F1(self.y_true, self.y_pred)
162+
return (
163+
self._micro_F1(self.y_true, self.y_pred)
164+
if not self.macro_averaging
165+
else self._macro_F1(self.y_true, self.y_pred)
166+
)
163167

164168
def __reset__(self):
165169
self.y_true = []
166170
self.y_pred = []
167171
return None
168-
169-

CollaborativeCoding/models/solveig_model.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,24 @@
44

55
def find_fc_input_shape(image_shape, model):
66
"""
7-
Find the shape of the input to the fully connected layer after passing through the convolutional layers.
8-
9-
Code inspired by @Seilmast (https://github.com/SFI-Visual-Intelligence/Collaborative-Coding-Exam/issues/67#issuecomment-2651212254)
10-
11-
Args
12-
----
13-
image_shape : tuple(int, int, int)
14-
Shape of the input image (C, H, W), where C is the number of channels,
15-
H is the height, and W is the width of the image.
16-
model : nn.Module
17-
The CNN model containing the convolutional layers, whose output size is used to
18-
determine the number of input features for the fully connected layer.
19-
20-
Returns
21-
-------
22-
int
23-
The number of elements in the input to the fully connected layer.
24-
"""
7+
Find the shape of the input to the fully connected layer after passing through the convolutional layers.
8+
9+
Code inspired by @Seilmast (https://github.com/SFI-Visual-Intelligence/Collaborative-Coding-Exam/issues/67#issuecomment-2651212254)
10+
11+
Args
12+
----
13+
image_shape : tuple(int, int, int)
14+
Shape of the input image (C, H, W), where C is the number of channels,
15+
H is the height, and W is the width of the image.
16+
model : nn.Module
17+
The CNN model containing the convolutional layers, whose output size is used to
18+
determine the number of input features for the fully connected layer.
19+
20+
Returns
21+
-------
22+
int
23+
The number of elements in the input to the fully connected layer.
24+
"""
2525

2626
dummy_img = torch.randn(1, *image_shape)
2727
with torch.no_grad():

main.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import numpy as np
22
import torch as th
33
import torch.nn as nn
4+
import wandb
45
from torch.utils.data import DataLoader
56
from torchvision import transforms
67
from tqdm import tqdm
78

8-
import wandb
99
from CollaborativeCoding import (
1010
MetricWrapper,
1111
createfolders,
@@ -17,7 +17,6 @@
1717
# from wandb_api import WANDB_API
1818

1919

20-
2120
def main():
2221
"""
2322

tests/test_metrics.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,12 @@ def test_f1score():
9191
macro_f1_score = f1_macro.__returnmetric__()
9292

9393
# Check if outputs are tensors
94-
assert isinstance(micro_f1_score, torch.Tensor), "Micro F1 score should be a tensor."
95-
assert isinstance(macro_f1_score, torch.Tensor), "Macro F1 score should be a tensor."
94+
assert isinstance(micro_f1_score, torch.Tensor), (
95+
"Micro F1 score should be a tensor."
96+
)
97+
assert isinstance(macro_f1_score, torch.Tensor), (
98+
"Macro F1 score should be a tensor."
99+
)
96100

97101
# Check that F1 scores are between 0 and 1
98102
assert 0 <= micro_f1_score.item() <= 1, "Micro F1 score should be between 0 and 1."

tests/test_wrappers.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,7 @@ def test_load_data(data_name, expected):
6565
assert len(dataset) > 0
6666
assert isinstance(dataset[0], tuple)
6767
assert isinstance(dataset[0][0], torch.Tensor)
68-
assert isinstance(
69-
dataset[0][1], int
70-
)
68+
assert isinstance(dataset[0][1], int)
7169

7270

7371
def test_load_metric():

0 commit comments

Comments
 (0)