Skip to content

Commit 6d2b2e6

Browse files
committed
Fixed a bug where SVHN was downloaded twice
1 parent 1327fde commit 6d2b2e6

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

CollaborativeCoding/dataloaders/download.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,11 @@ def download_svhn(path, train: bool = True):
9898
train_data = parent_path / "train_32x32.mat"
9999
test_data = parent_path / "test_32x32.mat"
100100

101-
if not train_data.is_file():
101+
if not train_data.exists():
102102
download_svhn(parent_path, train=True)
103-
if not test_data.is_file():
103+
if not test_data.exists():
104104
download_svhn(parent_path, train=False)
105-
print(test_data)
105+
106106
train_labels = loadmat(train_data)["y"]
107107
test_labels = loadmat(test_data)["y"]
108108

CollaborativeCoding/dataloaders/svhn.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(
3131
"""
3232
super().__init__()
3333

34-
self.data_path = data_path
34+
self.data_path = data_path / "SVHN"
3535
self.indexes = sample_ids
3636
self.split = "train" if train else "test"
3737

@@ -41,7 +41,7 @@ def __init__(
4141
if not os.path.exists(
4242
os.path.join(self.data_path, f"svhn_{self.split}data.h5")
4343
):
44-
self._download_data(self.data_path)
44+
self._create_h5py(self.data_path)
4545

4646
assert os.path.exists(
4747
os.path.join(self.data_path, f"svhn_{self.split}data.h5")
@@ -53,15 +53,14 @@ def __init__(
5353

5454
self.num_classes = len(np.unique(self.labels))
5555

56-
def _download_data(self, path: str):
56+
def _create_h5py(self, path: str):
5757
"""
5858
Downloads the SVHN dataset to the specified directory.
5959
Args:
6060
path (str): The directory where the dataset will be downloaded.
6161
"""
6262
print(f"Downloading SVHN data into {path}")
6363

64-
SVHN(path, split=self.split, download=True)
6564
data = loadmat(os.path.join(path, f"{self.split}_32x32.mat"))
6665

6766
images, labels = data["X"], data["y"]

0 commit comments

Comments
 (0)