Commit 4481ef8

Trying to fix ruff and isort on push
1 parent d7e83d4 commit 4481ef8

4 files changed, +46 -37 lines changed


tests/test_dataloaders.py

Lines changed: 6 additions & 2 deletions
@@ -48,8 +48,12 @@ def test_svhn_dataset():
     )

     assert dataset.__len__() != 0
-    assert os.path.exists(os.path.join(tempdir, "test_32x32.mat")), f'No such file as test_32x32.mat. Try running download=True'
-    assert os.path.exists(os.path.join(tempdir, "svhn_testdata.h5")), f'No such file as svhn_testdata.h5. Try running download=True'
+    assert os.path.exists(os.path.join(tempdir, "test_32x32.mat")), (
+        f"No such file as test_32x32.mat. Try running download=True"
+    )
+    assert os.path.exists(os.path.join(tempdir, "svhn_testdata.h5")), (
+        f"No such file as svhn_testdata.h5. Try running download=True"
+    )

     img, label = dataset.__getitem__(0)
     assert len(img.size()) == 3 and img.size() == (1, 28, 28) and img.size(0) == 1

tests/test_metrics.py

Lines changed: 6 additions & 3 deletions
@@ -104,11 +104,14 @@ def test_entropypred():

     pred_logits = th.rand(6, 5)
     true_lab = th.rand(6, 5)
-
+
     metric = EntropyPrediction(averages="mean")
     metric2 = EntropyPrediction(averages="sum")
-
+
     # Test for averaging metric consistency
     metric(true_lab, pred_logits)
     metric2(true_lab, pred_logits)
-    assert (th.abs(th.sum(6 * metric.__returnmetric__() - metric2.__returnmetric__())) < 1e-5)
+    assert (
+        th.abs(th.sum(6 * metric.__returnmetric__() - metric2.__returnmetric__()))
+        < 1e-5
+    )
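
For reference, the assertion above encodes the relation that, over the same 6 samples, the "sum" aggregation equals 6 times the "mean" aggregation. A minimal standalone sketch of that relation with plain tensors (EntropyPrediction itself is not needed to see it):

import torch as th

# Stand-ins for the per-sample entropy values the metric stores internally;
# shape (6,) mirrors the 6-sample batch used in the test.
per_sample = th.rand(6)

mean_agg = per_sample.mean()  # what averages="mean" reduces to
sum_agg = per_sample.sum()    # what averages="sum" reduces to

# With 6 samples, sum == 6 * mean up to floating-point error.
assert th.abs(6 * mean_agg - sum_agg) < 1e-5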

utils/dataloaders/svhn.py

Lines changed: 27 additions & 23 deletions
@@ -1,6 +1,6 @@
 import os

-import h5py
+import h5py
 import numpy as np
 from PIL import Image
 from scipy.io import loadmat
@@ -37,15 +37,16 @@ def __init__(

         self.nr_channels = nr_channels
         self.transforms = transform
-
-
-        assert os.path.exists(os.path.join(self.data_path, f'svhn_{self.split}data.h5')), f'File svhn_{self.split}data.h5 does not exists. Run download=True'
-        with h5py.File(os.path.join(self.data_path, f'svhn_{self.split}data.h5'), 'r') as h5f:
-            self.labels = h5f['labels'][:]
-
+
+        assert os.path.exists(
+            os.path.join(self.data_path, f"svhn_{self.split}data.h5")
+        ), f"File svhn_{self.split}data.h5 does not exists. Run download=True"
+        with h5py.File(
+            os.path.join(self.data_path, f"svhn_{self.split}data.h5"), "r"
+        ) as h5f:
+            self.labels = h5f["labels"][:]
+
         self.num_classes = len(np.unique(self.labels))
-
-

     def _download_data(self, path: str):
         """
@@ -55,17 +56,19 @@ def _download_data(self, path: str):
         """
         print(f"Downloading SVHN data into {path}")
         SVHN(path, split=self.split, download=True)
-        data = loadmat(os.path.join(path, f'{self.split}_32x32.mat'))
+        data = loadmat(os.path.join(path, f"{self.split}_32x32.mat"))

-        images, labels = data['X'], data['y']
-        images = images.transpose(3,1,0,2)
+        images, labels = data["X"], data["y"]
+        images = images.transpose(3, 1, 0, 2)
         labels[labels == 10] = 0
         labels = labels.flatten()
-
-        with h5py.File(os.path.join(self.data_path, f'svhn_{self.split}data.h5'), 'w') as h5f:
-            h5f.create_dataset('images', data=images)
-            h5f.create_dataset('labels', data=labels)
-
+
+        with h5py.File(
+            os.path.join(self.data_path, f"svhn_{self.split}data.h5"), "w"
+        ) as h5f:
+            h5f.create_dataset("images", data=images)
+            h5f.create_dataset("labels", data=labels)
+
     def __len__(self):
         """
         Returns the number of samples in the dataset.
@@ -83,14 +86,15 @@ def __getitem__(self, index):
             tuple: A tuple containing the image and its corresponding label.
         """
         lab = self.labels[index]
-        with h5py.File(os.path.join(self.data_path, f'svhn_{self.split}data.h5'), 'r') as h5f:
-            img = Image.fromarray(h5f['images'][index])
-
+        with h5py.File(
+            os.path.join(self.data_path, f"svhn_{self.split}data.h5"), "r"
+        ) as h5f:
+            img = Image.fromarray(h5f["images"][index])
+
         if self.nr_channels == 1:
-            img = img.convert('L')
-
+            img = img.convert("L")
+
         if self.transforms is not None:
             img = self.transforms(img)

         return img, lab
-
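
The reformatted _download_data above converts the torchvision download into an HDF5 cache. A minimal standalone sketch of that step, assuming SVHN here is torchvision.datasets.SVHN (its import sits outside this hunk) and using cache_svhn_split, path, and split purely as illustrative names:

import os

import h5py
from scipy.io import loadmat
from torchvision.datasets import SVHN


def cache_svhn_split(path: str, split: str = "test") -> None:
    # Download e.g. test_32x32.mat via torchvision, then re-save it as HDF5.
    SVHN(path, split=split, download=True)
    data = loadmat(os.path.join(path, f"{split}_32x32.mat"))
    images, labels = data["X"], data["y"]
    images = images.transpose(3, 1, 0, 2)  # move the sample axis to the front
    labels[labels == 10] = 0               # SVHN stores the digit 0 as label 10
    labels = labels.flatten()
    with h5py.File(os.path.join(path, f"svhn_{split}data.h5"), "w") as h5f:
        h5f.create_dataset("images", data=images)
        h5f.create_dataset("labels", data=labels)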

utils/metrics/EntropyPred.py

Lines changed: 7 additions & 9 deletions
@@ -1,6 +1,6 @@
-import torch as th
-import torch.nn as nn
 import numpy as np
+import torch as th
+import torch.nn as nn
 from scipy.stats import entropy


@@ -34,21 +34,20 @@ def __call__(self, y_true: th.Tensor, y_logits: th.Tensor):
             torch.Tensor: The aggregated entropy value(s) based on the specified
             method ('mean', 'sum', or 'none').
         """
-
-        assert len(y_logits.size()) == 2, f'y_logits shape: {y_logits.size()}'
+
+        assert len(y_logits.size()) == 2, f"y_logits shape: {y_logits.size()}"
         y_pred = nn.Softmax(dim=1)(y_logits)
-        print(f'y_pred: {y_pred}')
+        print(f"y_pred: {y_pred}")
         entropy_values = entropy(y_pred, axis=1)
         entropy_values = th.from_numpy(entropy_values)

         # Fix numerical errors for perfect guesses
         entropy_values[entropy_values == th.inf] = 0
         entropy_values = th.nan_to_num(entropy_values)
-        print(f'Entropy Values: {entropy_values}')
+        print(f"Entropy Values: {entropy_values}")
         for sample in entropy_values:
             self.stored_entropy_values.append(sample.item())

-
     def __returnmetric__(self):
         stored_entropy_values = th.from_numpy(np.asarray(self.stored_entropy_values))

@@ -57,9 +56,8 @@ def __returnmetric__(self):
         elif self.averages == "sum":
             stored_entropy_values = th.sum(stored_entropy_values)
         elif self.averages == "none":
-            pass
+            pass
         return stored_entropy_values

     def __reset__(self):
         self.stored_entropy_values = []
-
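
As a quick reference, the core computation in __call__ above (softmax over the logits, then one Shannon entropy per sample via scipy) can be reproduced standalone; the batch shape below is an arbitrary illustration, not part of the metric's API:

import torch as th
import torch.nn as nn
from scipy.stats import entropy

y_logits = th.rand(6, 5)              # 6 samples, 5 classes
y_pred = nn.Softmax(dim=1)(y_logits)  # per-sample class probabilities
# __call__ passes the tensor straight to scipy; converting explicitly here keeps
# the sketch self-contained.
entropy_values = th.from_numpy(entropy(y_pred.numpy(), axis=1))
print(entropy_values)  # one entropy value per sample, shape (6,)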
