Skip to content

Commit 2f227a6

Browse files
authored
Merge pull request #113 from SFI-Visual-Intelligence/mag-branch
SVHN bug fix
2 parents c3c546e + a932d64 commit 2f227a6

File tree

5 files changed

+18
-13
lines changed

5 files changed

+18
-13
lines changed

CollaborativeCoding/dataloaders/download.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,11 @@ def download_svhn(path, train: bool = True):
9898
train_data = parent_path / "train_32x32.mat"
9999
test_data = parent_path / "test_32x32.mat"
100100

101-
if not train_data.is_file():
101+
if not train_data.exists():
102102
download_svhn(parent_path, train=True)
103-
if not test_data.is_file():
103+
if not test_data.exists():
104104
download_svhn(parent_path, train=False)
105-
print(test_data)
105+
106106
train_labels = loadmat(train_data)["y"]
107107
test_labels = loadmat(test_data)["y"]
108108

CollaborativeCoding/dataloaders/svhn.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(
3131
"""
3232
super().__init__()
3333

34-
self.data_path = data_path
34+
self.data_path = data_path / "SVHN"
3535
self.indexes = sample_ids
3636
self.split = "train" if train else "test"
3737

@@ -41,7 +41,7 @@ def __init__(
4141
if not os.path.exists(
4242
os.path.join(self.data_path, f"svhn_{self.split}data.h5")
4343
):
44-
self._download_data(self.data_path)
44+
self._create_h5py(self.data_path)
4545

4646
assert os.path.exists(
4747
os.path.join(self.data_path, f"svhn_{self.split}data.h5")
@@ -53,15 +53,14 @@ def __init__(
5353

5454
self.num_classes = len(np.unique(self.labels))
5555

56-
def _download_data(self, path: str):
56+
def _create_h5py(self, path: str):
5757
"""
5858
Downloads the SVHN dataset to the specified directory.
5959
Args:
6060
path (str): The directory where the dataset will be downloaded.
6161
"""
6262
print(f"Downloading SVHN data into {path}")
6363

64-
SVHN(path, split=self.split, download=True)
6564
data = loadmat(os.path.join(path, f"{self.split}_32x32.mat"))
6665

6766
images, labels = data["X"], data["y"]

CollaborativeCoding/dataloaders/uspsh5_7_9.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,3 @@ def __getitem__(self, id):
102102
image = self.transform(image)
103103

104104
return image, label
105-

CollaborativeCoding/metrics/F1.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ def __init__(self, num_classes, macro_averaging=False):
2323
self.y_true = []
2424
self.y_pred = []
2525

26-
2726
def forward(self, target, preds):
2827
"""
2928
Stores predictions and targets for computing the F1 score.
@@ -57,7 +56,11 @@ def compute_f1(self):
5756
y_true = torch.cat(self.y_true)
5857
y_pred = torch.cat(self.y_pred)
5958

60-
return self._macro_F1(y_true, y_pred) if self.macro_averaging else self._micro_F1(y_true, y_pred)
59+
return (
60+
self._macro_F1(y_true, y_pred)
61+
if self.macro_averaging
62+
else self._micro_F1(y_true, y_pred)
63+
)
6164

6265
def _micro_F1(self, target, preds):
6366
"""Computes Micro F1 Score (global TP, FP, FN)."""
@@ -111,9 +114,13 @@ def __returnmetric__(self):
111114
y_true = torch.cat([t.unsqueeze(0) if t.dim() == 0 else t for t in self.y_true])
112115
y_pred = torch.cat([t.unsqueeze(0) if t.dim() == 0 else t for t in self.y_pred])
113116

114-
return self._macro_F1(y_true, y_pred) if self.macro_averaging else self._micro_F1(y_true, y_pred)
117+
return (
118+
self._macro_F1(y_true, y_pred)
119+
if self.macro_averaging
120+
else self._micro_F1(y_true, y_pred)
121+
)
115122

116123
def __reset__(self):
117124
"""Resets stored predictions and targets."""
118125
self.y_true = []
119-
self.y_pred = []
126+
self.y_pred = []

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import numpy as np
22
import torch as th
33
import torch.nn as nn
4+
import wandb
45
from torch.utils.data import DataLoader
56
from torchvision import transforms
67
from tqdm import tqdm
78

8-
import wandb
99
from CollaborativeCoding import (
1010
MetricWrapper,
1111
createfolders,

0 commit comments

Comments
 (0)