Skip to content

Commit 64fac10

Browse files
committed
ruffedisorted
1 parent 5d0d296 commit 64fac10

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

utils/dataloaders/mnist_4_9.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from .datasources import MNIST_SOURCE
77

8+
89
class MNISTDataset4_9(Dataset):
910
"""
1011
MNIST dataset of numbers 4-9.
@@ -18,37 +19,37 @@ class MNISTDataset4_9(Dataset):
1819
train : bool, optional
1920
Whether to train the model or not, by default False
2021
"""
22+
2123
def __init__(self, data_path: Path, sample_ids: np.ndarray, train: bool = False):
2224
super.__init__()
2325
self.data_path = data_path
2426
self.mnist_path = self.data_path / "MNIST"
2527
self.samples = sample_ids
2628
self.train = train
27-
29+
2830
self.images_path = self.mnist_path / (
2931
MNIST_SOURCE["train_images"][1] if train else MNIST_SOURCE["test_images"][1]
3032
)
3133
self.labels_path = self.mnist_path / (
3234
MNIST_SOURCE["train_labels"][1] if train else MNIST_SOURCE["test_labels"][1]
3335
)
34-
35-
36+
3637
def __len__(self):
3738
return len(self.samples)
38-
39+
3940
def __getitem__(self, idx):
4041
with open(self.labels_path, "rb") as labelfile:
4142
label_pos = 8 + self.sample[idx]
42-
labelfile.seek(label_pos)
43-
label = int.from_bytes(labelfile.read(1), byteorder="big")
43+
labelfile.seek(label_pos)
44+
label = int.from_bytes(labelfile.read(1), byteorder="big")
4445

4546
with open(self.images_path, "rb") as imagefile:
4647
image_pos = 16 + self.samples[idx] * 28 * 28
4748
imagefile.seek(image_pos)
4849
image = np.frombuffer(imagefile.read(28 * 28), dtype=np.uint8).reshape(
4950
28, 28
50-
)
51+
)
5152

5253
image = np.expand_dims(image, axis=0) # Channel
53-
54-
return image, label
54+
55+
return image, label

0 commit comments

Comments
 (0)