Skip to content

Commit 952d665

Browse files
authored
Merge pull request #730 from jaybdub/batchnorm3d_import_fix
batchnorm3d import fix
2 parents 458394f + 5541ec9 commit 952d665

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

torch2trt/converters/BatchNorm3d.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,9 @@ def convert_BatchNorm3d(ctx):
2121
layer = ctx.network.add_scale(input_trt, trt.ScaleMode.CHANNEL, bias, scale, power)
2222

2323
output._trt = layer.get_output(0)
24+
25+
26+
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 16, 16, 16)])
27+
@add_module_test(torch.float32, torch.device('cuda'), [(2, 3, 16, 16, 16)], max_batch_size=2)
28+
def test_BatchNorm3d_basic():
29+
return torch.nn.BatchNorm3d(3)

torch2trt/converters/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .AdaptiveAvgPool2d import *
88
from .BatchNorm1d import *
99
from .BatchNorm2d import *
10+
from .BatchNorm3d import *
1011
from .clone import *
1112
from .conv_functional import *
1213
from .Conv import *

0 commit comments

Comments
 (0)