Skip to content

Commit 3074b2a

Browse files
Merge branch 'main' into gradient
2 parents 1ff476d + fddd06a commit 3074b2a

File tree

4 files changed

+45
-14
lines changed

4 files changed

+45
-14
lines changed

RELEASENOTES.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ Releases, starting with 9/2/2021, are listed with the most recent release at the
66

77
__Breaking Changes__:
88

9+
- `torchvision.dataset.MNIST` will try more mirrors.
10+
- The thrown exception might be changed when it fails to download `MNIST`, `FashionMNIST` or `KMNIST`.
11+
912
__API Changes__:
1013

1114
- #1291 `Tensor.grad()` and `Tensor.set_grad()` have been replaced by a new property `Tensor.grad`.

src/TorchVision/dsets/CIFAR.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,26 @@ protected void DownloadFile(string file, string target, string baseUrl)
9393
lock (_httpClient) {
9494
using var s = _httpClient.GetStreamAsync(netPath).Result;
9595
using var fs = new FileStream(filePath, FileMode.CreateNew);
96-
s.CopyToAsync(fs).Wait();
96+
s.CopyTo(fs);
9797
}
9898
}
9999
}
100100

101+
protected void DownloadFile(string file, string target, IEnumerable<string> baseUrls)
102+
{
103+
var exceptions = new List<Exception>();
104+
foreach (var baseUrl in baseUrls) {
105+
try {
106+
DownloadFile(file, target, baseUrl);
107+
return;
108+
} catch (Exception e) {
109+
exceptions.Add(e);
110+
continue;
111+
}
112+
}
113+
throw new AggregateException($"Error downloading {file}", exceptions);
114+
}
115+
101116
protected static string JoinPaths(string directory, string file)
102117
{
103118
#if NETSTANDARD2_0_OR_GREATER

src/TorchVision/dsets/MNIST.cs

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ namespace Modules
7676
/// </summary>
7777
internal class MNIST : DatasetHelper
7878
{
79+
private static string[] Mirrors => new[] {
80+
"http://yann.lecun.com/exdb/mnist/",
81+
"https://ossci-datasets.s3.amazonaws.com/mnist/"
82+
};
83+
7984
/// <summary>
8085
/// Constructor
8186
/// </summary>
@@ -84,13 +89,13 @@ internal class MNIST : DatasetHelper
8489
/// <param name="download"></param>
8590
/// <param name="transform">Transform for input MNIST image</param>
8691
public MNIST(string root, bool train, bool download = false, torchvision.ITransform transform = null) :
87-
this(root, "mnist", train ? "train" : "t10k", "http://yann.lecun.com/exdb/mnist/", download, transform)
92+
this(root, "mnist", train ? "train" : "t10k", Mirrors, download, transform)
8893
{
8994
}
9095

91-
protected MNIST(string root, string datasetName, string prefix, string baseUrl, bool download, torchvision.ITransform transform)
96+
protected MNIST(string root, string datasetName, string prefix, IEnumerable<string> baseUrls, bool download, torchvision.ITransform transform)
9297
{
93-
if (download) Download(root, baseUrl, datasetName);
98+
if (download) Download(root, baseUrls, datasetName);
9499

95100
this.transform = transform;
96101

@@ -156,7 +161,7 @@ protected MNIST(string root, string datasetName, string prefix, string baseUrl,
156161
}
157162
}
158163

159-
private void Download(string root, string baseUrl, string dataset)
164+
private void Download(string root, IEnumerable<string> baseUrls, string dataset)
160165
{
161166
#if NETSTANDARD2_0_OR_GREATER
162167
var datasetPath = NSPath.Join(root, dataset);
@@ -171,10 +176,10 @@ private void Download(string root, string baseUrl, string dataset)
171176
Directory.CreateDirectory(sourceDir);
172177
}
173178

174-
DownloadFile("train-images-idx3-ubyte.gz", sourceDir, baseUrl);
175-
DownloadFile("train-labels-idx1-ubyte.gz", sourceDir, baseUrl);
176-
DownloadFile("t10k-images-idx3-ubyte.gz", sourceDir, baseUrl);
177-
DownloadFile("t10k-labels-idx1-ubyte.gz", sourceDir, baseUrl);
179+
DownloadFile("train-images-idx3-ubyte.gz", sourceDir, baseUrls);
180+
DownloadFile("train-labels-idx1-ubyte.gz", sourceDir, baseUrls);
181+
DownloadFile("t10k-images-idx3-ubyte.gz", sourceDir, baseUrls);
182+
DownloadFile("t10k-labels-idx1-ubyte.gz", sourceDir, baseUrls);
178183

179184
if (!Directory.Exists(targetDir)) {
180185
Directory.CreateDirectory(targetDir);
@@ -229,6 +234,10 @@ public override Dictionary<string, Tensor> GetTensor(long index)
229234

230235
internal class FashionMNIST : MNIST
231236
{
237+
private static string[] Mirrors => new[] {
238+
"https://github.com/zalandoresearch/fashion-mnist/raw/master/data/fashion/"
239+
};
240+
232241
/// <summary>
233242
/// Constructor
234243
/// </summary>
@@ -237,13 +246,17 @@ internal class FashionMNIST : MNIST
237246
/// <param name="download"></param>
238247
/// <param name="transform">Transform for input MNIST image</param>
239248
public FashionMNIST(string root, bool train, bool download = false, torchvision.ITransform transform = null) :
240-
base(root, "fashion-mnist", train ? "train" : "t10k", "https://github.com/zalandoresearch/fashion-mnist/raw/master/data/fashion/", download, transform)
249+
base(root, "fashion-mnist", train ? "train" : "t10k", Mirrors, download, transform)
241250
{
242251
}
243252
}
244253

245254
internal class KMNIST : MNIST
246255
{
256+
private static string[] Mirrors => new[] {
257+
"http://codh.rois.ac.jp/kmnist/dataset/kmnist/"
258+
};
259+
247260
/// <summary>
248261
/// Constructor
249262
/// </summary>
@@ -252,7 +265,7 @@ internal class KMNIST : MNIST
252265
/// <param name="download"></param>
253266
/// <param name="transform">Transform for input MNIST image</param>
254267
public KMNIST(string root, bool train, bool download = false, torchvision.ITransform transform = null) :
255-
base(root, "kmnist", train ? "train" : "t10k", "http://codh.rois.ac.jp/kmnist/dataset/kmnist/", download, transform)
268+
base(root, "kmnist", train ? "train" : "t10k", Mirrors, download, transform)
256269
{
257270
}
258271
}

test/TorchSharpTest/LinearAlgebra.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -404,19 +404,19 @@ public void SolveTriangularTest()
404404
var A = randn(3, 3).triu_();
405405
var b = randn(3, 4);
406406
var x = linalg.solve_triangular(A, b, upper: true);
407-
Assert.True(A.matmul(x).allclose(b, rtol: 1e-03, atol: 1e-06));
407+
Assert.True(A.matmul(x).allclose(b, rtol: 1e-03, atol: 1e-05));
408408
}
409409
{
410410
var A = randn(2, 3, 3).tril_();
411411
var b = randn(2, 3, 4);
412412
var x = linalg.solve_triangular(A, b, upper: false);
413-
Assert.True(A.matmul(x).allclose(b, rtol: 1e-03, atol: 1e-06));
413+
Assert.True(A.matmul(x).allclose(b, rtol: 1e-03, atol: 1e-05));
414414
}
415415
{
416416
var A = randn(2, 4, 4).tril_();
417417
var b = randn(2, 3, 4);
418418
var x = linalg.solve_triangular(A, b, upper: false, left: false);
419-
Assert.True(x.matmul(A).allclose(b, rtol: 1e-03, atol: 1e-06));
419+
Assert.True(x.matmul(A).allclose(b, rtol: 1e-03, atol: 1e-05));
420420
}
421421
}
422422

0 commit comments

Comments
 (0)