from pathlib import Path

from torch.utils.data import Dataset
import numpy as np
import urllib.request
import gzip
import os


class MNISTDataset0_3(Dataset):
    """
    A custom dataset class for loading MNIST data, specifically for digits 0 through 3.

    Parameters
    ----------
    data_path : Path
        The root directory where the MNIST data is stored or will be downloaded.
    train : bool, optional
        If True, loads the training data; otherwise loads the test data. Default is False.
    transform : callable, optional
        A function/transform that takes in an image and returns a transformed version. Default is None.
    download : bool, optional
        If True, downloads the dataset if it is not already present in the specified data_path. Default is False.

    Attributes
    ----------
    data_path : Path
        The root directory where the MNIST data is stored.
    mnist_path : Path
        The directory where the MNIST data files are stored.
    train : bool
        Indicates whether the training data or the test data is being used.
    transform : callable
        A function/transform that takes in an image and returns a transformed version.
    download : bool
        Indicates whether the dataset should be downloaded if not present.
    images_path : Path
        The path to the image file (training or test) based on the `train` flag.
    labels_path : Path
        The path to the label file (training or test) based on the `train` flag.
    idx : numpy.ndarray
        Indices of the samples whose label is less than 4.
    length : int
        The number of samples in the dataset.

    Methods
    -------
    _parse_labels(train)
        Parses the labels from the label file.
    _check_is_downloaded()
        Checks whether the dataset is already downloaded.
    _download_data()
        Downloads and extracts the MNIST dataset.
    __len__()
        Returns the number of samples in the dataset.
    __getitem__(index)
        Returns the image and label at the specified index.
    """

    def __init__(self, data_path: Path, train: bool = False, transform=None, download: bool = False):
        super().__init__()

        self.data_path = data_path
        self.mnist_path = self.data_path / "MNIST"
        self.train = train
        self.transform = transform
        self.download = download

        if self.download and not self._check_is_downloaded():
            self._download_data()

        self.images_path = self.mnist_path / ("train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte")
        self.labels_path = self.mnist_path / ("train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte")

        labels = self._parse_labels(train=self.train)

        # Keep only the samples whose label is 0, 1, 2, or 3.
        self.idx = np.where(labels < 4)[0]
        self.length = len(self.idx)

    def _parse_labels(self, train):
        # The IDX label file has an 8-byte header followed by one uint8 label per sample.
        with open(self.labels_path, "rb") as f:
            data = np.frombuffer(f.read(), dtype=np.uint8, offset=8)
        return data

    def _check_is_downloaded(self):
        if self.mnist_path.exists():
            required_files = [
                "train-images-idx3-ubyte",
                "train-labels-idx1-ubyte",
                "t10k-images-idx3-ubyte",
                "t10k-labels-idx1-ubyte",
            ]
            if all((self.mnist_path / file).exists() for file in required_files):
                print("Data already downloaded.")
                return True
            return False
        else:
            self.mnist_path.mkdir(parents=True, exist_ok=True)
            return False

    def _download_data(self):
        urls = {
            "train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
            "train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
            "test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
            "test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
        }

        for name, url in urls.items():
            file_path = os.path.join(self.mnist_path, url.split("/")[-1])
            if not os.path.exists(file_path.replace(".gz", "")):  # Avoid re-downloading
                urllib.request.urlretrieve(url, file_path)
                with gzip.open(file_path, "rb") as f_in:
                    with open(file_path.replace(".gz", ""), "wb") as f_out:
                        f_out.write(f_in.read())
                os.remove(file_path)  # Remove the compressed file

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Map the requested index to the position of a sample whose label is < 4.
        real_index = int(self.idx[index])

        with open(self.labels_path, "rb") as f:
            f.seek(8 + real_index)  # 8-byte header, then one byte per label
            label = int.from_bytes(f.read(1), byteorder="big")

        with open(self.images_path, "rb") as f:
            f.seek(16 + real_index * 28 * 28)  # 16-byte header, then 784 bytes per image
            image = np.frombuffer(f.read(28 * 28), dtype=np.uint8).reshape(28, 28)

        if self.transform:
            image = self.transform(image)

        return image, label
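

# Minimal usage sketch (illustrative, not part of the original module). It assumes the data
# can be downloaded into ./data and uses a simple lambda transform that scales pixels to
# [0, 1]; any torchvision-style transform would work the same way.
if __name__ == "__main__":
    import torch
    from torch.utils.data import DataLoader

    dataset = MNISTDataset0_3(
        Path("./data"),
        train=True,
        download=True,
        transform=lambda img: torch.from_numpy(img.copy()).float() / 255.0,
    )
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    images, labels = next(iter(loader))
    print(images.shape)  # expected: torch.Size([32, 28, 28])
    print(labels[:10])   # every label should be in {0, 1, 2, 3}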