Skip to content

Commit 5b11bac

Browse files
Merge pull request #1280 from yueyinqiu/1279
correct torch.finfo
2 parents 7ef79ff + aa2d2c5 commit 5b11bac

File tree

4 files changed

+125
-25
lines changed

4 files changed

+125
-25
lines changed

RELEASENOTES.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ Releases, starting with 9/2/2021, are listed with the most recent release at the
44

55
# NuGet Version 0.102.4
66

7+
__Breaking Changes__:
8+
9+
Correct `torch.finfo`. (`torch.set_default_dtype`, `Categorical.entropy`, `_CorrCholesky.check`, `Distribution.ClampProbs`, `FisherSnedecor.rsample`, `Gamma.rsample`, `Geometric.rsample`, `distributions.Gumbel`, `Laplace.rsample`, `SigmoidTransform._call` and `SigmoidTransform._inverse` are influenced.)<br/>
10+
711
__API Changes__:
812

913
#1284 make `torch.unique` and `torch.unique_consecutive` public.<br/>

src/TorchSharp/Tensor/Tensor.cs

Lines changed: 50 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7226,37 +7226,63 @@ public struct FInfo
72267226
public double max;
72277227
public double min;
72287228
public double tiny;
7229+
public double smallest_normal;
7230+
public double resolution;
72297231
}
72307232

7231-
public static FInfo finfo(ScalarType dtype)
7233+
public static FInfo finfo()
72327234
{
7233-
if (!is_floating_point(dtype) && !is_complex(dtype))
7234-
throw new ArgumentException("'dtype' must be floating point or complex");
7235-
7236-
if (dtype == ScalarType.ComplexFloat32)
7237-
dtype = ScalarType.Float32;
7238-
if (dtype == ScalarType.ComplexFloat64)
7239-
dtype = ScalarType.Float64;
7240-
7241-
FInfo result = new FInfo();
7235+
return finfo(default_dtype);
7236+
}
72427237

7238+
public static FInfo finfo(ScalarType dtype)
7239+
{
72437240
switch (dtype) {
7241+
case ScalarType.BFloat16:
7242+
return new FInfo() {
7243+
bits = 16,
7244+
eps = 0.0078125,
7245+
max = 3.3895313892515355e+38,
7246+
min = -3.3895313892515355e+38,
7247+
tiny = 1.1754943508222875e-38,
7248+
smallest_normal = 1.1754943508222875e-38,
7249+
resolution = 0.01
7250+
};
7251+
case ScalarType.Float16:
7252+
return new FInfo() {
7253+
bits = 16,
7254+
eps = 0.0009765625,
7255+
max = 65504.0,
7256+
min = -65504.0,
7257+
tiny = 6.103515625e-05,
7258+
smallest_normal = 6.103515625e-05,
7259+
resolution = 0.001
7260+
};
72447261
case ScalarType.Float32:
7245-
result.bits = 32;
7246-
result.min = float.MinValue;
7247-
result.max = float.MaxValue;
7248-
result.eps = float.Epsilon;
7249-
result.tiny = float.Epsilon;
7250-
break;
7262+
case ScalarType.ComplexFloat32:
7263+
return new FInfo() {
7264+
bits = 32,
7265+
eps = 1.1920928955078125e-07,
7266+
max = 3.4028234663852886e+38,
7267+
min = -3.4028234663852886e+38,
7268+
tiny = 1.1754943508222875e-38,
7269+
smallest_normal = 1.1754943508222875e-38,
7270+
resolution = 1e-06
7271+
};
72517272
case ScalarType.Float64:
7252-
result.bits = 64;
7253-
result.min = double.MinValue;
7254-
result.max = double.MaxValue;
7255-
result.eps = double.Epsilon;
7256-
result.tiny = double.Epsilon;
7257-
break;
7258-
}
7259-
return result;
7273+
case ScalarType.ComplexFloat64:
7274+
return new FInfo() {
7275+
bits = 64,
7276+
eps = 2.220446049250313e-16,
7277+
max = 1.7976931348623157e+308,
7278+
min = -1.7976931348623157e+308,
7279+
tiny = 2.2250738585072014e-308,
7280+
smallest_normal = 2.2250738585072014e-308,
7281+
resolution = 1e-15
7282+
};
7283+
default:
7284+
throw new ArgumentException("'dtype' must be floating point or complex");
7285+
}
72607286
}
72617287

72627288
public static bool is_integral(ScalarType type)

src/TorchSharp/Tensor/torch.Tensors.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,13 @@ public static partial class torch
4646
/// The default floating point dtype is initially torch.float32.
4747
/// </summary>
4848
/// <param name="dtype"></param>
49-
public static void set_default_dtype(ScalarType dtype) { default_dtype = dtype; }
49+
public static void set_default_dtype(ScalarType dtype)
50+
{
51+
if (!dtype.IsFloatingPoint()) {
52+
throw new ArgumentException("only floating-point types are supported as the default type");
53+
}
54+
default_dtype = dtype;
55+
}
5056

5157
// https://pytorch.org/docs/stable/generated/torch.get_default_dtype
5258
/// <summary>

test/TorchSharpTest/TestTorchSharp.cs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,70 @@ namespace TorchSharp
1515
[Collection("Sequential")]
1616
public class TestTorch
1717
{
18+
[Fact]
19+
public void FInfoTest()
20+
{
21+
static void AssertScalarEqual(Tensor expected, Tensor actual)
22+
{
23+
Assert.Equal((double)expected, (double)actual);
24+
}
25+
static void AssertScalarNotEqual(Tensor expected, Tensor actual)
26+
{
27+
Assert.NotEqual((double)expected, (double)actual);
28+
}
29+
30+
var floatingTypes = new[] {
31+
ScalarType.Float16, ScalarType.Float32, ScalarType.Float64, ScalarType.BFloat16
32+
};
33+
foreach (var scalarType in floatingTypes) {
34+
var info = finfo(scalarType);
35+
36+
var zeroPointFour = tensor(0.4, scalarType);
37+
var zeroPointNine = tensor(0.9, scalarType);
38+
var one = tensor(1, scalarType);
39+
var zero = tensor(1, scalarType);
40+
Assert.Equal(one.dtype.ElementSize() * 8, info.bits);
41+
42+
var eps = tensor(info.eps, scalarType);
43+
AssertScalarNotEqual(one, one + eps);
44+
AssertScalarEqual(one + eps, one + eps * zeroPointNine);
45+
AssertScalarEqual(one, one + eps * zeroPointFour);
46+
47+
var max = tensor(info.max, scalarType);
48+
AssertScalarEqual(max, max + eps);
49+
50+
var min = tensor(info.min, scalarType);
51+
AssertScalarEqual(-max, min);
52+
AssertScalarEqual(min, min - eps);
53+
54+
var tiny = tensor(info.tiny, scalarType);
55+
// not sure how to test for tiny.
56+
57+
var smallest_normal = tensor(info.smallest_normal, scalarType);
58+
AssertScalarEqual(tiny, smallest_normal);
59+
60+
var resolution = tensor(info.resolution, scalarType);
61+
// not sure how to test for resolution.
62+
}
63+
64+
var complexTypes = new[] {
65+
(ScalarType.ComplexFloat32, ScalarType.Float32),
66+
(ScalarType.ComplexFloat64, ScalarType.Float64)
67+
};
68+
foreach (var (complex, floating) in complexTypes) {
69+
var c = finfo(complex);
70+
var f = finfo(floating);
71+
72+
Assert.Equal(f.bits, c.bits);
73+
Assert.Equal(f.eps, c.eps);
74+
Assert.Equal(f.max, c.max);
75+
Assert.Equal(f.min, c.min);
76+
Assert.Equal(f.tiny, c.tiny);
77+
Assert.Equal(f.smallest_normal, c.smallest_normal);
78+
Assert.Equal(f.resolution, c.resolution);
79+
}
80+
}
81+
1882
[Fact]
1983
public void EnumEquivalence()
2084
{

0 commit comments

Comments
 (0)