@@ -76,6 +76,11 @@ namespace Modules
76
76
/// </summary>
77
77
internal class MNIST : DatasetHelper
78
78
{
79
+ private static string [ ] Mirrors => new [ ] {
80
+ "http://yann.lecun.com/exdb/mnist/" ,
81
+ "https://ossci-datasets.s3.amazonaws.com/mnist/"
82
+ } ;
83
+
79
84
/// <summary>
80
85
/// Constructor
81
86
/// </summary>
@@ -84,13 +89,13 @@ internal class MNIST : DatasetHelper
84
89
/// <param name="download"></param>
85
90
/// <param name="transform">Transform for input MNIST image</param>
86
91
public MNIST ( string root , bool train , bool download = false , torchvision . ITransform transform = null ) :
87
- this ( root , "mnist" , train ? "train" : "t10k" , "http://yann.lecun.com/exdb/mnist/" , download , transform )
92
+ this ( root , "mnist" , train ? "train" : "t10k" , Mirrors , download , transform )
88
93
{
89
94
}
90
95
91
- protected MNIST ( string root , string datasetName , string prefix , string baseUrl , bool download , torchvision . ITransform transform )
96
+ protected MNIST ( string root , string datasetName , string prefix , IEnumerable < string > baseUrls , bool download , torchvision . ITransform transform )
92
97
{
93
- if ( download ) Download ( root , baseUrl , datasetName ) ;
98
+ if ( download ) Download ( root , baseUrls , datasetName ) ;
94
99
95
100
this . transform = transform ;
96
101
@@ -156,7 +161,7 @@ protected MNIST(string root, string datasetName, string prefix, string baseUrl,
156
161
}
157
162
}
158
163
159
- private void Download ( string root , string baseUrl , string dataset )
164
+ private void Download ( string root , IEnumerable < string > baseUrls , string dataset )
160
165
{
161
166
#if NETSTANDARD2_0_OR_GREATER
162
167
var datasetPath = NSPath . Join ( root , dataset ) ;
@@ -171,10 +176,10 @@ private void Download(string root, string baseUrl, string dataset)
171
176
Directory . CreateDirectory ( sourceDir ) ;
172
177
}
173
178
174
- DownloadFile ( "train-images-idx3-ubyte.gz" , sourceDir , baseUrl ) ;
175
- DownloadFile ( "train-labels-idx1-ubyte.gz" , sourceDir , baseUrl ) ;
176
- DownloadFile ( "t10k-images-idx3-ubyte.gz" , sourceDir , baseUrl ) ;
177
- DownloadFile ( "t10k-labels-idx1-ubyte.gz" , sourceDir , baseUrl ) ;
179
+ DownloadFile ( "train-images-idx3-ubyte.gz" , sourceDir , baseUrls ) ;
180
+ DownloadFile ( "train-labels-idx1-ubyte.gz" , sourceDir , baseUrls ) ;
181
+ DownloadFile ( "t10k-images-idx3-ubyte.gz" , sourceDir , baseUrls ) ;
182
+ DownloadFile ( "t10k-labels-idx1-ubyte.gz" , sourceDir , baseUrls ) ;
178
183
179
184
if ( ! Directory . Exists ( targetDir ) ) {
180
185
Directory . CreateDirectory ( targetDir ) ;
@@ -229,6 +234,10 @@ public override Dictionary<string, Tensor> GetTensor(long index)
229
234
230
235
internal class FashionMNIST : MNIST
231
236
{
237
+ private static string [ ] Mirrors => new [ ] {
238
+ "https://github.com/zalandoresearch/fashion-mnist/raw/master/data/fashion/"
239
+ } ;
240
+
232
241
/// <summary>
233
242
/// Constructor
234
243
/// </summary>
@@ -237,13 +246,17 @@ internal class FashionMNIST : MNIST
237
246
/// <param name="download"></param>
238
247
/// <param name="transform">Transform for input MNIST image</param>
239
248
public FashionMNIST ( string root , bool train , bool download = false , torchvision . ITransform transform = null ) :
240
- base ( root , "fashion-mnist" , train ? "train" : "t10k" , "https://github.com/zalandoresearch/fashion-mnist/raw/master/data/fashion/" , download , transform )
249
+ base ( root , "fashion-mnist" , train ? "train" : "t10k" , Mirrors , download , transform )
241
250
{
242
251
}
243
252
}
244
253
245
254
internal class KMNIST : MNIST
246
255
{
256
+ private static string [ ] Mirrors => new [ ] {
257
+ "http://codh.rois.ac.jp/kmnist/dataset/kmnist/"
258
+ } ;
259
+
247
260
/// <summary>
248
261
/// Constructor
249
262
/// </summary>
@@ -252,7 +265,7 @@ internal class KMNIST : MNIST
252
265
/// <param name="download"></param>
253
266
/// <param name="transform">Transform for input MNIST image</param>
254
267
public KMNIST ( string root , bool train , bool download = false , torchvision . ITransform transform = null ) :
255
- base ( root , "kmnist" , train ? "train" : "t10k" , "http://codh.rois.ac.jp/kmnist/dataset/kmnist/" , download , transform )
268
+ base ( root , "kmnist" , train ? "train" : "t10k" , Mirrors , download , transform )
256
269
{
257
270
}
258
271
}
0 commit comments