55
66from .datasources import MNIST_SOURCE
77
8+
89class MNISTDataset4_9 (Dataset ):
910 """
1011 MNIST dataset of numbers 4-9.
@@ -18,37 +19,37 @@ class MNISTDataset4_9(Dataset):
1819 train : bool, optional
1920 Whether to train the model or not, by default False
2021 """
22+
2123 def __init__ (self , data_path : Path , sample_ids : np .ndarray , train : bool = False ):
2224 super .__init__ ()
2325 self .data_path = data_path
2426 self .mnist_path = self .data_path / "MNIST"
2527 self .samples = sample_ids
2628 self .train = train
27-
29+
2830 self .images_path = self .mnist_path / (
2931 MNIST_SOURCE ["train_images" ][1 ] if train else MNIST_SOURCE ["test_images" ][1 ]
3032 )
3133 self .labels_path = self .mnist_path / (
3234 MNIST_SOURCE ["train_labels" ][1 ] if train else MNIST_SOURCE ["test_labels" ][1 ]
3335 )
34-
35-
36+
3637 def __len__ (self ):
3738 return len (self .samples )
38-
39+
3940 def __getitem__ (self , idx ):
4041 with open (self .labels_path , "rb" ) as labelfile :
4142 label_pos = 8 + self .sample [idx ]
42- labelfile .seek (label_pos )
43- label = int .from_bytes (labelfile .read (1 ), byteorder = "big" )
43+ labelfile .seek (label_pos )
44+ label = int .from_bytes (labelfile .read (1 ), byteorder = "big" )
4445
4546 with open (self .images_path , "rb" ) as imagefile :
4647 image_pos = 16 + self .samples [idx ] * 28 * 28
4748 imagefile .seek (image_pos )
4849 image = np .frombuffer (imagefile .read (28 * 28 ), dtype = np .uint8 ).reshape (
4950 28 , 28
50- )
51+ )
5152
5253 image = np .expand_dims (image , axis = 0 ) # Channel
53-
54- return image , label
54+
55+ return image , label
0 commit comments