@@ -325,8 +325,8 @@ public void FunctionalBilinearNoBias()
325
325
Assert . Equal ( 40 , forward . shape [ 1 ] ) ;
326
326
Assert . Equal ( device . type , forward . device_type ) ;
327
327
}
328
- }
329
-
328
+ }
329
+
330
330
[ Fact ]
331
331
public void TestLinearFused ( )
332
332
{
@@ -3272,7 +3272,7 @@ public void TestCustomComponentName()
3272
3272
Assert . True ( sd . ContainsKey ( "_linear2.weight" ) ) ;
3273
3273
3274
3274
// The field names are also retrieved in the `_toEpilogue` function, so we want to make sure
3275
- // that everything works after calling a `.to` function.
3275
+ // that everything works after calling a `.to` function.
3276
3276
model = model . to ( ScalarType . BFloat16 ) ;
3277
3277
3278
3278
sd = model . state_dict ( ) ;
@@ -4496,6 +4496,26 @@ public void TestBatchNorm2D()
4496
4496
}
4497
4497
}
4498
4498
4499
+
4500
+
4501
+ [ Fact ]
4502
+ public void TestMovingBatchNorm2D ( )
4503
+ {
4504
+ Device ? device = torch . mps_is_available ( ) ? torch . MPS : torch . cuda_is_available ( ) ? torch . CUDA : null ;
4505
+
4506
+ if ( device is not null ) {
4507
+ using ( var pool = BatchNorm2d ( 32 ) ) {
4508
+ Assert . NotNull ( pool . num_batches_tracked ) ;
4509
+ Assert . False ( pool . num_batches_tracked . IsInvalid ) ;
4510
+
4511
+ pool . to ( device ) ;
4512
+
4513
+ Assert . NotNull ( pool . num_batches_tracked ) ;
4514
+ Assert . False ( pool . num_batches_tracked . IsInvalid ) ;
4515
+ }
4516
+ }
4517
+ }
4518
+
4499
4519
[ Fact ]
4500
4520
public void TestBatchNorm2dWeightAndBias ( )
4501
4521
{
@@ -6713,7 +6733,7 @@ public void TestModulePostHooks()
6713
6733
}
6714
6734
6715
6735
[ Fact ]
6716
- public void TestCustomParameterLessModule ( )
6736
+ public void TestCustomParameterLessModule ( )
6717
6737
{
6718
6738
var cnp = new CustomNoParameters ( "test" ) ;
6719
6739
0 commit comments