1- import gzip
2- import os
3- import urllib .request as ur
41from pathlib import Path
52
3+ import numpy as np
64from torch .utils .data import Dataset
75
6+ from .datasources import MNIST_SOURCE
87
98class MNIST_4_9 (Dataset ):
10- def __init__ (self , datapath : Path , train : bool = False , download : bool = False ):
9+ """
10+ MNIST dataset of numbers 4-9.
11+
12+ Parameters
13+ ----------
14+ data_path : Path
15+ Root directory where MNIST dataset is stored
16+ sample_ids : np.ndarray
17+ Array of indices spcifying which samples to load. This determines the samples used by the dataloader.
18+ train : bool, optional
19+ Whether to train the model or not, by default False
20+ """
21+ def __init__ (self , data_path : Path , sample_ids : np .ndarray , train : bool = False ):
1122 super .__init__ ()
12- self .datapath = datapath
13- self .mnist_path = self .datapath / "MNIST"
23+ self .data_path = data_path
24+ self .mnist_path = self .data_path / "MNIST"
25+ self .samples = sample_ids
1426 self .train = train
15- self .download = download
16- self .num_classes : int = 6
17-
18- if not self .download and not self ._already_downloaded ():
19- raise FileNotFoundError (
20- "Data files are not found. Set --download-data=True to download the data"
21- )
22- if self .download and not self ._already_downloaded ():
23- self ._download ()
24-
25- def _download (self ):
26- urls : dict = {
27- "train_images" : "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz" ,
28- "train_labels" : "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz" ,
29- "test_images" : "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz" ,
30- "test_labels" : "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz" ,
31- }
32-
33- for url in urls .values ():
34- file_path : Path = os .path .join (self .mnist_path , url .split ("/" )[- 1 ])
35- file_name : Path = file_path .replace (".gz" , "" )
36- if os .path .exists (file_name ):
37- print (f"File: { file_name } already downloaded" )
38- else :
39- print (f"File: { file_name } is downloading..." )
40- ur .urlretrieve (url , file_path ) # Download file
41- with gzip .open (file_path , "rb" ) as infile :
42- with open (file_name , "wb" ) as outfile :
43- outfile .write (infile .read ()) # Write from url to local file
44- os .remove (file_path ) # remove .gz file
45-
46- def _already_downloaded (self ):
47- if self .mnist_path .exists ():
48- required_files : list = [
49- "train-images-idx3-ubyte" ,
50- "train-labels-idx1-ubyte" ,
51- "t10k-images-idx3-ubyte" ,
52- "t10k-labels-idx1-ubyte" ,
53- ]
54- return all ([(self .mnist_path / file ).exists () for file in required_files ])
55-
56- else :
57- self .mnist_path .mkdir (parents = True , exist_ok = True )
58- return False
59-
27+
28+ self .images_path = self .mnist_path / (
29+ MNIST_SOURCE ["train_images" ][1 ] if train else MNIST_SOURCE ["test_images" ][1 ]
30+ )
31+ self .labels_path = self .mnist_path / (
32+ MNIST_SOURCE ["train_labels" ][1 ] if train else MNIST_SOURCE ["test_labels" ][1 ]
33+ )
34+
35+
6036 def __len__ (self ):
61- pass
62-
63- def __getitem__ (self ):
64- pass
37+ return len (self .samples )
38+
39+ def __getitem__ (self , idx ):
40+ with open (self .labels_path , "rb" ) as labelfile :
41+ label_pos = 8 + self .sample [idx ]
42+ labelfile .seek (label_pos )
43+ label = int .from_bytes (labelfile .read (1 ), byteorder = "big" )
44+
45+ with open (self .images_path , "rb" ) as imagefile :
46+ image_pos = 16 + self .samples [idx ] * 28 * 28
47+ imagefile .seek (image_pos )
48+ image = np .frombuffer (imagefile .read (28 * 28 ), dtype = np .uint8 ).reshape (
49+ 28 , 28
50+ )
51+
52+ image = np .expand_dims (image , axis = 0 ) # Channel
53+
54+ return image , label
0 commit comments