Skip to content

Commit 601caca

Browse files
committed
ruffed, isorted
1 parent 15c99ea commit 601caca

File tree

3 files changed

+35
-25
lines changed

3 files changed

+35
-25
lines changed

utils/dataloaders/datasources.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,24 @@
1919
}
2020

2121
MNIST_SOURCE = {
22-
"train_images": ["https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
23-
"train-images-idx3-ubyte",
24-
None
22+
"train_images": [
23+
"https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
24+
"train-images-idx3-ubyte",
25+
None,
2526
],
26-
"train_labels": ["https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
27-
"train-labels-idx1-ubyte",
28-
None
27+
"train_labels": [
28+
"https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
29+
"train-labels-idx1-ubyte",
30+
None,
2931
],
30-
"test_images": ["https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
31-
"t10k-images-idx3-ubyte",
32-
None
32+
"test_images": [
33+
"https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
34+
"t10k-images-idx3-ubyte",
35+
None,
3336
],
34-
"test_labels": ["https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
35-
"t10k-labels-idx1-ubyte",
36-
None
37+
"test_labels": [
38+
"https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
39+
"t10k-labels-idx1-ubyte",
40+
None,
3741
],
3842
}

utils/dataloaders/download.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import bz2
2+
import gzip
23
import hashlib
34
import os
4-
import gzip
55
from pathlib import Path
66
from tempfile import TemporaryDirectory
77
from urllib.request import urlretrieve
88

99
import h5py as h5
1010
import numpy as np
1111

12-
from .datasources import USPS_SOURCE, MNIST_SOURCE
12+
from .datasources import MNIST_SOURCE, USPS_SOURCE
1313

1414

1515
class Downloader:
@@ -52,35 +52,36 @@ def _chech_is_downloaded(path: Path) -> bool:
5252
else:
5353
path.mkdir(parents=True, exist_ok=True)
5454
return False
55-
55+
5656
def _download_data(path: Path) -> None:
5757
urls = {key: MNIST_SOURCE[key][0] for key in MNIST_SOURCE.keys()}
5858

5959
for name, url in urls.items():
6060
file_path = os.path.join(path, url.split("/")[-1])
61-
if not os.path.exists(file_path.replace(".gz", "")): # Avoid re-downloading
61+
if not os.path.exists(
62+
file_path.replace(".gz", "")
63+
): # Avoid re-downloading
6264
urlretrieve(url, file_path)
6365
with gzip.open(file_path, "rb") as f_in:
6466
with open(file_path.replace(".gz", ""), "wb") as f_out:
6567
f_out.write(f_in.read())
6668
os.remove(file_path) # Remove compressed file
67-
69+
6870
def _get_labels(path: Path) -> np.ndarray:
6971
with open(path, "rb") as f:
7072
data = np.frombuffer(f.read(), dtype=np.uint8, offset=8)
7173
return data
72-
74+
7375
if not _chech_is_downloaded(data_dir):
7476
_download_data(data_dir)
75-
77+
7678
train_labels_path = data_dir / "MNIST" / MNIST_SOURCE["train_labels"][1]
7779
test_labels_path = data_dir / "MNIST" / MNIST_SOURCE["test_labels"][1]
78-
80+
7981
train_labels = _get_labels(train_labels_path)
8082
test_labels = _get_labels(test_labels_path)
81-
83+
8284
return train_labels, test_labels
83-
8485

8586
def svhn(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
8687
raise NotImplementedError("SVHN download not implemented yet")

utils/dataloaders/mnist_0_3.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
from torch.utils.data import Dataset
5+
56
from .datasources import MNIST_SOURCE
67

78

@@ -62,11 +63,15 @@ def __init__(
6263
self.transform = transform
6364
self.num_classes = 4
6465

65-
self.images_path = self.mnist_path / (MNIST_SOURCE["train_images"][1] if train else MNIST_SOURCE["test_images"][1])
66-
self.labels_path = self.mnist_path / (MNIST_SOURCE["train_labels"][1] if train else MNIST_SOURCE["test_labels"][1])
66+
self.images_path = self.mnist_path / (
67+
MNIST_SOURCE["train_images"][1] if train else MNIST_SOURCE["test_images"][1]
68+
)
69+
self.labels_path = self.mnist_path / (
70+
MNIST_SOURCE["train_labels"][1] if train else MNIST_SOURCE["test_labels"][1]
71+
)
6772

6873
self.length = len(self.idx)
69-
74+
7075
def __len__(self):
7176
return self.length
7277

0 commit comments

Comments
 (0)