|
1 | 1 | import bz2 |
| 2 | +import gzip |
2 | 3 | import hashlib |
3 | 4 | import os |
4 | | -import gzip |
5 | 5 | from pathlib import Path |
6 | 6 | from tempfile import TemporaryDirectory |
7 | 7 | from urllib.request import urlretrieve |
8 | 8 |
|
9 | 9 | import h5py as h5 |
10 | 10 | import numpy as np |
11 | 11 |
|
12 | | -from .datasources import USPS_SOURCE, MNIST_SOURCE |
| 12 | +from .datasources import MNIST_SOURCE, USPS_SOURCE |
13 | 13 |
|
14 | 14 |
|
15 | 15 | class Downloader: |
@@ -52,35 +52,36 @@ def _chech_is_downloaded(path: Path) -> bool: |
52 | 52 | else: |
53 | 53 | path.mkdir(parents=True, exist_ok=True) |
54 | 54 | return False |
55 | | - |
| 55 | + |
56 | 56 | def _download_data(path: Path) -> None: |
57 | 57 | urls = {key: MNIST_SOURCE[key][0] for key in MNIST_SOURCE.keys()} |
58 | 58 |
|
59 | 59 | for name, url in urls.items(): |
60 | 60 | file_path = os.path.join(path, url.split("/")[-1]) |
61 | | - if not os.path.exists(file_path.replace(".gz", "")): # Avoid re-downloading |
| 61 | + if not os.path.exists( |
| 62 | + file_path.replace(".gz", "") |
| 63 | + ): # Avoid re-downloading |
62 | 64 | urlretrieve(url, file_path) |
63 | 65 | with gzip.open(file_path, "rb") as f_in: |
64 | 66 | with open(file_path.replace(".gz", ""), "wb") as f_out: |
65 | 67 | f_out.write(f_in.read()) |
66 | 68 | os.remove(file_path) # Remove compressed file |
67 | | - |
| 69 | + |
68 | 70 | def _get_labels(path: Path) -> np.ndarray: |
69 | 71 | with open(path, "rb") as f: |
70 | 72 | data = np.frombuffer(f.read(), dtype=np.uint8, offset=8) |
71 | 73 | return data |
72 | | - |
| 74 | + |
73 | 75 | if not _chech_is_downloaded(data_dir): |
74 | 76 | _download_data(data_dir) |
75 | | - |
| 77 | + |
76 | 78 | train_labels_path = data_dir / "MNIST" / MNIST_SOURCE["train_labels"][1] |
77 | 79 | test_labels_path = data_dir / "MNIST" / MNIST_SOURCE["test_labels"][1] |
78 | | - |
| 80 | + |
79 | 81 | train_labels = _get_labels(train_labels_path) |
80 | 82 | test_labels = _get_labels(test_labels_path) |
81 | | - |
| 83 | + |
82 | 84 | return train_labels, test_labels |
83 | | - |
84 | 85 |
|
85 | 86 | def svhn(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]: |
86 | 87 | raise NotImplementedError("SVHN download not implemented yet") |
|
0 commit comments