Skip to content

Commit e5817bf

Browse files
Added validation check in L1 Loss unit tests
1 parent e0bd0f3 commit e5817bf

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

test/TorchSharpTest/NN.cs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,15 +1409,16 @@ public void TestCrossEntropyLossF()
14091409
public void TestL1Loss()
14101410
{
14111411
foreach (var device in TestUtils.AvailableDevices()) {
1412-
using (Tensor input = torch.rand(new long[] { 5, 2 }, device: device))
1413-
using (Tensor target = torch.rand(new long[] { 5, 2 }, device: device)) {
1412+
using (Tensor input = torch.arange(10, dtype:float32, device: device))
1413+
using (Tensor target = torch.zeros(10, dtype:float32, device: device)) {
14141414
var outTensor = L1Loss().call(input, target);
14151415
Assert.Equal(device.type, outTensor.device_type);
14161416
var values = outTensor.data<float>().ToArray();
14171417
Assert.Multiple(
14181418
() => Assert.Empty(outTensor.shape),
14191419
() => Assert.Single(values),
1420-
() => Assert.False(float.IsNaN(values[0]))
1420+
() => Assert.False(float.IsNaN(values[0])),
1421+
() => Assert.Equal(4.5, values[0])
14211422
);
14221423
}
14231424
}
@@ -1427,15 +1428,16 @@ public void TestL1Loss()
14271428
public void TestL1LossF()
14281429
{
14291430
foreach (var device in TestUtils.AvailableDevices()) {
1430-
using (Tensor input = torch.rand(new long[] { 5, 2 }, device: device))
1431-
using (Tensor target = torch.rand(new long[] { 5, 2 }, device: device)) {
1431+
using (Tensor input = torch.arange(10, dtype:float32, device: device))
1432+
using (Tensor target = torch.zeros(10, dtype:float32, device: device)) {
14321433
var outTensor = l1_loss(input, target);
14331434
Assert.Equal(device.type, outTensor.device_type);
14341435
var values = outTensor.data<float>().ToArray();
14351436
Assert.Multiple(
14361437
() => Assert.Empty(outTensor.shape),
14371438
() => Assert.Single(values),
1438-
() => Assert.False(float.IsNaN(values[0]))
1439+
() => Assert.False(float.IsNaN(values[0])),
1440+
() => Assert.Equal(4.5, values[0])
14391441
);
14401442
}
14411443
}

0 commit comments

Comments
 (0)