1- from pathlib import Path
2-
3- from torch .utils .data import Dataset
4- import numpy as np
5- import urllib .request
61import gzip
72import os
3+ import urllib .request
4+ from pathlib import Path
85
6+ import numpy as np
7+ from torch .utils .data import Dataset
98
109
1110class MNISTDataset0_3 (Dataset ):
@@ -54,39 +53,56 @@ class MNISTDataset0_3(Dataset):
5453 __getitem__(index)
5554 Returns the image and label at the specified index.
5655 """
57- def __init__ (self , data_path : Path , train : bool = False , transform = None , download : bool = False ,):
56+
57+ def __init__ (
58+ self ,
59+ data_path : Path ,
60+ train : bool = False ,
61+ transform = None ,
62+ download : bool = False ,
63+ ):
5864 super ().__init__ ()
59-
65+
6066 self .data_path = data_path
6167 self .mnist_path = self .data_path / "MNIST"
6268 self .train = train
6369 self .transform = transform
6470 self .download = download
6571 self .num_classes = 4
66-
72+
6773 if not self .download and not self ._chech_is_downloaded ():
68- raise ValueError ("Data not found. Set --download-data=True to download the data." )
74+ raise ValueError (
75+ "Data not found. Set --download-data=True to download the data."
76+ )
6977 if self .download and not self ._chech_is_downloaded ():
7078 self ._download_data ()
71-
72- self .images_path = self .mnist_path / ("train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte" )
73- self .labels_path = self .mnist_path / ("train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte" )
74-
79+
80+ self .images_path = self .mnist_path / (
81+ "train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte"
82+ )
83+ self .labels_path = self .mnist_path / (
84+ "train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte"
85+ )
86+
7587 labels = self ._parse_labels (train = self .train )
76-
77- self .idx = np .where (labels < 4 )[0 ]
78-
88+
89+ self .idx = np .where (labels < 4 )[0 ]
90+
7991 self .length = len (self .idx )
80-
81-
92+
8293 def _parse_labels (self , train ):
8394 with open (self .labels_path , "rb" ) as f :
8495 data = np .frombuffer (f .read (), dtype = np .uint8 , offset = 8 )
8596 return data
86-
97+
8798 def _chech_is_downloaded (self ):
8899 if self .mnist_path .exists ():
89- required_files = ["train-images-idx3-ubyte" , "train-labels-idx1-ubyte" , "t10k-images-idx3-ubyte" , "t10k-labels-idx1-ubyte" ]
100+ required_files = [
101+ "train-images-idx3-ubyte" ,
102+ "train-labels-idx1-ubyte" ,
103+ "t10k-images-idx3-ubyte" ,
104+ "t10k-labels-idx1-ubyte" ,
105+ ]
90106 if all ([(self .mnist_path / file ).exists () for file in required_files ]):
91107 print ("MNIST Dataset already downloaded." )
92108 return True
@@ -95,26 +111,24 @@ def _chech_is_downloaded(self):
95111 else :
96112 self .mnist_path .mkdir (parents = True , exist_ok = True )
97113 return False
98-
99-
114+
100115 def _download_data (self ):
101116 urls = {
102- "train_images" : "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz" ,
103- "train_labels" : "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz" ,
104- "test_images" : "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz" ,
105- "test_labels" : "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz" ,
106- }
107-
117+ "train_images" : "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz" ,
118+ "train_labels" : "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz" ,
119+ "test_images" : "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz" ,
120+ "test_labels" : "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz" ,
121+ }
122+
108123 for name , url in urls .items ():
109124 file_path = os .path .join (self .mnist_path , url .split ("/" )[- 1 ])
110125 if not os .path .exists (file_path .replace (".gz" , "" )): # Avoid re-downloading
111126 urllib .request .urlretrieve (url , file_path )
112- with gzip .open (file_path , 'rb' ) as f_in :
113- with open (file_path .replace (".gz" , "" ), 'wb' ) as f_out :
127+ with gzip .open (file_path , "rb" ) as f_in :
128+ with open (file_path .replace (".gz" , "" ), "wb" ) as f_out :
114129 f_out .write (f_in .read ())
115130 os .remove (file_path ) # Remove compressed file
116131
117-
118132 def __len__ (self ):
119133 return self .length
120134
@@ -124,12 +138,14 @@ def __getitem__(self, index):
124138 label = int .from_bytes (f .read (1 ), byteorder = "big" ) # Read 1 byte for label
125139
126140 with open (self .images_path , "rb" ) as f :
127- f .seek (16 + index * 28 * 28 ) # Jump to image position
128- image = np .frombuffer (f .read (28 * 28 ), dtype = np .uint8 ).reshape (28 , 28 ) # Read image data
129-
130- image = np .expand_dims (image , axis = 0 ) # Add channel dimension
131-
141+ f .seek (16 + index * 28 * 28 ) # Jump to image position
142+ image = np .frombuffer (f .read (28 * 28 ), dtype = np .uint8 ).reshape (
143+ 28 , 28
144+ ) # Read image data
145+
146+ image = np .expand_dims (image , axis = 0 ) # Add channel dimension
147+
132148 if self .transform :
133149 image = self .transform (image )
134-
135- return image , label
150+
151+ return image , label
0 commit comments