@@ -1409,15 +1409,16 @@ public void TestCrossEntropyLossF()
1409
1409
public void TestL1Loss ( )
1410
1410
{
1411
1411
foreach ( var device in TestUtils . AvailableDevices ( ) ) {
1412
- using ( Tensor input = torch . rand ( new long [ ] { 5 , 2 } , device : device ) )
1413
- using ( Tensor target = torch . rand ( new long [ ] { 5 , 2 } , device : device ) ) {
1412
+ using ( Tensor input = torch . arange ( 10 , dtype : float32 , device : device ) )
1413
+ using ( Tensor target = torch . zeros ( 10 , dtype : float32 , device : device ) ) {
1414
1414
var outTensor = L1Loss ( ) . call ( input , target ) ;
1415
1415
Assert . Equal ( device . type , outTensor . device_type ) ;
1416
1416
var values = outTensor . data < float > ( ) . ToArray ( ) ;
1417
1417
Assert . Multiple (
1418
1418
( ) => Assert . Empty ( outTensor . shape ) ,
1419
1419
( ) => Assert . Single ( values ) ,
1420
- ( ) => Assert . False ( float . IsNaN ( values [ 0 ] ) )
1420
+ ( ) => Assert . False ( float . IsNaN ( values [ 0 ] ) ) ,
1421
+ ( ) => Assert . Equal ( 4.5 , values [ 0 ] )
1421
1422
) ;
1422
1423
}
1423
1424
}
@@ -1427,15 +1428,16 @@ public void TestL1Loss()
1427
1428
public void TestL1LossF ( )
1428
1429
{
1429
1430
foreach ( var device in TestUtils . AvailableDevices ( ) ) {
1430
- using ( Tensor input = torch . rand ( new long [ ] { 5 , 2 } , device : device ) )
1431
- using ( Tensor target = torch . rand ( new long [ ] { 5 , 2 } , device : device ) ) {
1431
+ using ( Tensor input = torch . arange ( 10 , dtype : float32 , device : device ) )
1432
+ using ( Tensor target = torch . zeros ( 10 , dtype : float32 , device : device ) ) {
1432
1433
var outTensor = l1_loss ( input , target ) ;
1433
1434
Assert . Equal ( device . type , outTensor . device_type ) ;
1434
1435
var values = outTensor . data < float > ( ) . ToArray ( ) ;
1435
1436
Assert . Multiple (
1436
1437
( ) => Assert . Empty ( outTensor . shape ) ,
1437
1438
( ) => Assert . Single ( values ) ,
1438
- ( ) => Assert . False ( float . IsNaN ( values [ 0 ] ) )
1439
+ ( ) => Assert . False ( float . IsNaN ( values [ 0 ] ) ) ,
1440
+ ( ) => Assert . Equal ( 4.5 , values [ 0 ] )
1439
1441
) ;
1440
1442
}
1441
1443
}
0 commit comments