1+ import gzip
2+ import os
3+ import urllib .request as ur
4+ from pathlib import Path
5+ import numpy as np
6+ from torch .utils .data import Dataset
7+
class MNIST_4_9(Dataset):
    """Dataset wrapper that locates, and optionally downloads, the raw MNIST files.

    On construction the class checks that the four uncompressed MNIST IDX
    files are present under ``<datapath>/MNIST``. If they are missing they are
    either downloaded (``download=True``) or a ``FileNotFoundError`` is raised.

    Sample access (``__len__`` / ``__getitem__``) is not implemented yet.

    Args:
        datapath: Root data directory; the files live in its ``MNIST`` subfolder.
            Anything accepted by ``pathlib.Path`` works (``Path`` or ``str``).
        train: Stored on the instance as ``self.train``; not otherwise used here.
        download: When True, fetch any missing files instead of raising.

    Raises:
        FileNotFoundError: If the files are missing and ``download`` is False.
    """

    # Public mirror that hosts the gzip-compressed MNIST IDX files.
    _BASE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/"

    # Names of the uncompressed files that must exist for the data to count as
    # downloaded; each one is fetched from ``_BASE_URL + name + ".gz"``.
    _REQUIRED_FILES = (
        "train-images-idx3-ubyte",
        "train-labels-idx1-ubyte",
        "t10k-images-idx3-ubyte",
        "t10k-labels-idx1-ubyte",
    )

    def __init__(self,
                 datapath: Path,
                 train: bool = False,
                 download: bool = False
                 ):
        # ``super`` must be *called*; ``super.__init__()`` raises TypeError.
        super().__init__()
        # Coerce so that plain string paths are accepted as well.
        self.datapath = Path(datapath)
        self.mnist_path = self.datapath / "MNIST"
        self.train = train
        self.download = download
        # NOTE(review): presumably the six digit classes 4..9 suggested by the
        # class name - confirm once sample loading is implemented.
        self.num_classes: int = 6

        # Check the directory once, then either fail loudly or fetch the data.
        if not self._already_downloaded():
            if not self.download:
                raise FileNotFoundError(
                    'Data files are not found. Set --download-data=True to download the data'
                )
            self._download()

    def _download(self):
        """Download and decompress every required file that is not present yet.

        Files that already exist are skipped. The temporary ``.gz`` archive is
        removed afterwards, including when the download or the decompression
        fails part-way.
        """
        self.mnist_path.mkdir(parents=True, exist_ok=True)

        for name in self._REQUIRED_FILES:
            file_name = self.mnist_path / name          # final, uncompressed file
            file_path = self.mnist_path / (name + ".gz")  # temporary archive

            if file_name.exists():
                print(f"File: {file_name} already downloaded")
                continue

            print(f"File: {file_name} is downloading...")
            try:
                ur.urlretrieve(self._BASE_URL + name + ".gz", file_path)
                # Decompress fully before opening the target so that a corrupt
                # archive does not leave an empty file that looks downloaded.
                with gzip.open(file_path, 'rb') as infile:
                    payload = infile.read()
                with open(file_name, 'wb') as outfile:
                    outfile.write(payload)
            finally:
                # Never leave the compressed archive behind.
                if file_path.exists():
                    os.remove(file_path)

    def _already_downloaded(self):
        """Return True when every required MNIST file exists under ``mnist_path``.

        This is a pure check: it does not create any directory.
        """
        return all((self.mnist_path / name).exists() for name in self._REQUIRED_FILES)

    def __len__(self):
        """Return the number of samples (not implemented yet)."""
        # TODO: implement once the IDX files are parsed.
        raise NotImplementedError("MNIST_4_9.__len__ is not implemented yet")

    def __getitem__(self, index):
        """Return the sample at ``index`` (not implemented yet)."""
        # TODO: implement once the IDX files are parsed.
        raise NotImplementedError("MNIST_4_9.__getitem__ is not implemented yet")
74+
75+
0 commit comments