Skip to content

Commit 59152f3

Browse files
Added unit test for BatchNorm moved to MPS or CUDA.
1 parent 1b33637 commit 59152f3

File tree

1 file changed

+24
-4
lines changed

1 file changed

+24
-4
lines changed

test/TorchSharpTest/NN.cs

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -325,8 +325,8 @@ public void FunctionalBilinearNoBias()
325325
Assert.Equal(40, forward.shape[1]);
326326
Assert.Equal(device.type, forward.device_type);
327327
}
328-
}
329-
328+
}
329+
330330
[Fact]
331331
public void TestLinearFused()
332332
{
@@ -3272,7 +3272,7 @@ public void TestCustomComponentName()
32723272
Assert.True(sd.ContainsKey("_linear2.weight"));
32733273

32743274
// The field names are also retrieved in the `_toEpilogue` function, so we want to make sure
3275-
// that everything works after calling a `.to` function.
3275+
// that everything works after calling a `.to` function.
32763276
model = model.to(ScalarType.BFloat16);
32773277

32783278
sd = model.state_dict();
@@ -4496,6 +4496,26 @@ public void TestBatchNorm2D()
44964496
}
44974497
}
44984498

4499+
4500+
4501+
[Fact]
4502+
public void TestMovingBatchNorm2D()
4503+
{
4504+
Device? device = torch.mps_is_available() ? torch.MPS : torch.cuda_is_available() ? torch.CUDA : null;
4505+
4506+
if (device is not null) {
4507+
using (var pool = BatchNorm2d(32)) {
4508+
Assert.NotNull(pool.num_batches_tracked);
4509+
Assert.False(pool.num_batches_tracked.IsInvalid);
4510+
4511+
pool.to(device);
4512+
4513+
Assert.NotNull(pool.num_batches_tracked);
4514+
Assert.False(pool.num_batches_tracked.IsInvalid);
4515+
}
4516+
}
4517+
}
4518+
44994519
[Fact]
45004520
public void TestBatchNorm2dWeightAndBias()
45014521
{
@@ -6713,7 +6733,7 @@ public void TestModulePostHooks()
67136733
}
67146734

67156735
[Fact]
6716-
public void TestCustomParameterLessModule()
6736+
public void TestCustomParameterLessModule()
67176737
{
67186738
var cnp = new CustomNoParameters("test");
67196739

0 commit comments

Comments
 (0)