Skip to content

Commit 0293f36

Browse files
Merge pull request #1360 from NiklasGustafsson/bugs
Address issue #1359
2 parents 26240d5 + baf3061 commit 0293f36

File tree

4 files changed

+18
-10
lines changed

4 files changed

+18
-10
lines changed

RELEASENOTES.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@ Releases, starting with 9/2/2021, are listed with the most recent release at the
44

55
# NuGet Version 0.102.7
66

7+
__Bug Fixes__:
8+
9+
#1359 torch.nn.functional.l1_loss computed the mean squared error (MSE) instead of the mean absolute error (MAE).<br/>
10+
11+
# NuGet Version 0.102.6
12+
713
__Breaking Changes__:
814

915
When creating a tensor from a 1-D array, and passing in a shape, there is now an ambiguity between the IList and Memory overloads of `torch.tensor()`. The ambiguity is resolved by removing the `dimensions` argument if it is redundant, or by an explicit cast to IList if it is not.

build/BranchInfo.props

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
<PropertyGroup>
33
<MajorVersion>0</MajorVersion>
44
<MinorVersion>102</MinorVersion>
5-
<PatchVersion>7</PatchVersion>
6-
<PreviousPackageVersion>0.102.6</PreviousPackageVersion>
5+
<PatchVersion>8</PatchVersion>
6+
<PreviousPackageVersion>0.102.7</PreviousPackageVersion>
77
</PropertyGroup>
88

99
</Project>

src/Native/LibTorchSharp/THSLoss.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,10 @@ Tensor THSNN_kl_div_loss(const Tensor input, const Tensor target, const int64_t
9494
Tensor THSNN_l1_loss(const Tensor input, const Tensor target, const int64_t reduction)
9595
{
9696
CATCH_RETURN_Tensor(
97-
auto opts = torch::nn::functional::MSELossFuncOptions();
97+
auto opts = torch::nn::functional::L1LossFuncOptions();
9898
ApplyReduction(opts, reduction);
9999

100-
res = ResultTensor(torch::nn::functional::mse_loss(*input, *target, opts));
100+
res = ResultTensor(torch::nn::functional::l1_loss(*input, *target, opts));
101101
)
102102
}
103103

test/TorchSharpTest/NN.cs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,15 +1409,16 @@ public void TestCrossEntropyLossF()
14091409
public void TestL1Loss()
14101410
{
14111411
foreach (var device in TestUtils.AvailableDevices()) {
1412-
using (Tensor input = torch.rand(new long[] { 5, 2 }, device: device))
1413-
using (Tensor target = torch.rand(new long[] { 5, 2 }, device: device)) {
1412+
using (Tensor input = torch.arange(10, dtype:float32, device: device))
1413+
using (Tensor target = torch.zeros(10, dtype:float32, device: device)) {
14141414
var outTensor = L1Loss().call(input, target);
14151415
Assert.Equal(device.type, outTensor.device_type);
14161416
var values = outTensor.data<float>().ToArray();
14171417
Assert.Multiple(
14181418
() => Assert.Empty(outTensor.shape),
14191419
() => Assert.Single(values),
1420-
() => Assert.False(float.IsNaN(values[0]))
1420+
() => Assert.False(float.IsNaN(values[0])),
1421+
() => Assert.Equal(4.5, values[0])
14211422
);
14221423
}
14231424
}
@@ -1427,15 +1428,16 @@ public void TestL1Loss()
14271428
public void TestL1LossF()
14281429
{
14291430
foreach (var device in TestUtils.AvailableDevices()) {
1430-
using (Tensor input = torch.rand(new long[] { 5, 2 }, device: device))
1431-
using (Tensor target = torch.rand(new long[] { 5, 2 }, device: device)) {
1431+
using (Tensor input = torch.arange(10, dtype:float32, device: device))
1432+
using (Tensor target = torch.zeros(10, dtype:float32, device: device)) {
14321433
var outTensor = l1_loss(input, target);
14331434
Assert.Equal(device.type, outTensor.device_type);
14341435
var values = outTensor.data<float>().ToArray();
14351436
Assert.Multiple(
14361437
() => Assert.Empty(outTensor.shape),
14371438
() => Assert.Single(values),
1438-
() => Assert.False(float.IsNaN(values[0]))
1439+
() => Assert.False(float.IsNaN(values[0])),
1440+
() => Assert.Equal(4.5, values[0])
14391441
);
14401442
}
14411443
}

0 commit comments

Comments
 (0)