
Commit 98669fd

handle conv1d

1 parent bfe5bc5 commit 98669fd

File tree

2 files changed: +83 −35 lines changed

tests/converter_tests/test_converters.py

Lines changed: 21 additions & 15 deletions
@@ -179,22 +179,28 @@ def test_radd_float():
 # TODO: radd, add, iadd


+@pytest.mark.parametrize("with_conv", [True, False])
+@pytest.mark.parametrize("nd", [1, 2, 3])
+def test_batch_norm_nd(nd, with_conv):
+    modules = []
+    if nd == 1:
+        if with_conv:
+            modules.append(nn.Conv1d(3, 3, 1))  # with conv, because a standalone scale layer is sometimes not implemented
+        modules.append(nn.BatchNorm1d(3))
+    if nd == 2:
+        if with_conv:
+            modules.append(nn.Conv2d(3, 3, 1))
+        modules.append(nn.BatchNorm2d(3))
+    if nd == 3:
+        if with_conv:
+            modules.append(nn.Conv3d(3, 3, 1))
+        modules.append(nn.BatchNorm3d(3))
+
+    module = nn.Sequential(*modules).cuda().eval()
+
+    input_size = [2, 3] + [4] * nd

-def test_batch_norm_1d():
-    module = nn.BatchNorm1d(3).cuda().eval()
-    inputs = [torch.randn(2, 3, 4).cuda()]
-    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
-
-
-def test_batch_norm_2d():
-    module = nn.BatchNorm2d(3).cuda().eval()
-    inputs = [torch.randn(2, 3, 4, 4).cuda()]
-    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
-
-
-def test_batch_norm_3d():
-    module = nn.BatchNorm3d(3).cuda().eval()
-    inputs = [torch.randn(2, 3, 4, 4, 4).cuda()]
+    inputs = [torch.randn(*input_size).cuda()]
     cross_validate(module, inputs, fp16_mode=False, tol=1e-1)

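For orientation, here is a minimal end-to-end sketch (not part of the commit) of the nd == 1 case the new test exercises, written against torch2trt's usual torch2trt() entry point rather than the cross_validate helper; the module layout mirrors test_batch_norm_nd.

import torch
from torch import nn
from torch2trt import torch2trt

# The nd == 1, with_conv == True configuration from test_batch_norm_nd:
# a pointwise Conv1d followed by BatchNorm1d, evaluated on a (2, 3, 4) input.
module = nn.Sequential(nn.Conv1d(3, 3, 1), nn.BatchNorm1d(3)).cuda().eval()
inputs = [torch.randn(2, 3, 4).cuda()]

# Conversion of this model only succeeds after this commit, since the
# torch.nn.functional.conv1d converter previously raised NotImplementedError.
module_trt = torch2trt(module, inputs, fp16_mode=False)

with torch.no_grad():
    out_torch = module(*inputs)
    out_trt = module_trt(*inputs)
print(torch.max(torch.abs(out_torch - out_trt)))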

torch2trt/converters/native_converters.py

Lines changed: 62 additions & 20 deletions
@@ -342,10 +342,6 @@ def convert_gt(ctx):


 @tensorrt_converter('torch.nn.functional.conv1d')
-def convert_conv1d(ctx):
-    raise NotImplementedError
-
-
 @tensorrt_converter('torch.nn.functional.conv2d')
 @tensorrt_converter('torch.nn.functional.conv3d')
 def convert_conv2d3d(ctx):
@@ -376,31 +372,53 @@ def convert_conv2d3d(ctx):
     if not isinstance(dilation, tuple):
         dilation = (dilation, ) * input_dim

+
     kernel = weight.detach().cpu().numpy()

     if bias is not None:
         bias = bias.detach().cpu().numpy()

-    layer = ctx.network.add_convolution_nd(
-        input=input_trt,
+    # Handle reshape 1D to 2D
+    if input_dim == 1:
+        kernel_size = kernel_size + (1,)
+        stride = stride + (1,)
+        padding = padding + (0,)
+        dilation = dilation + (1,)
+        unsqueeze_layer = ctx.network.add_shuffle(input_trt)
+        set_layer_precision(ctx, unsqueeze_layer)
+        unsqueeze_layer.reshape_dims = tuple([0] * input.ndim) + (1,)
+        conv_input = unsqueeze_layer.get_output(0)
+    else:
+        conv_input = input_trt
+
+
+    conv_layer = ctx.network.add_convolution_nd(
+        input=conv_input,
         num_output_maps=out_channels,
         kernel_shape=kernel_size,
         kernel=kernel,
         bias=bias)
-    layer.stride_nd = stride
-    layer.padding_nd = padding
-    layer.dilation_nd = dilation
+    conv_layer.stride_nd = stride
+    conv_layer.padding_nd = padding
+    conv_layer.dilation_nd = dilation

     if groups is not None:
-        layer.num_groups = groups
+        conv_layer.num_groups = groups

-    output._trt = layer.get_output(0)
+    output._trt = conv_layer.get_output(0)
+
+    # Handle reshape 2D back to 1D
+    if input_dim == 1:
+        squeeze_layer = ctx.network.add_shuffle(conv_layer.get_output(0))
+        set_layer_precision(ctx, squeeze_layer)
+        squeeze_layer.reshape_dims = tuple([0] * input.ndim)
+        output._trt = squeeze_layer.get_output(0)
+    else:
+        output._trt = conv_layer.get_output(0)


-@tensorrt_converter('torch.nn.functional.conv_transpose1d')
-def convert_conv_transpose1d(ctx):
-    raise NotImplementedError

+@tensorrt_converter('torch.nn.functional.conv_transpose1d')
 @tensorrt_converter('torch.nn.functional.conv_transpose2d')
 @tensorrt_converter('torch.nn.functional.conv_transpose3d')
 def convert_conv_transpose2d3d(ctx):
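The change above implements conv1d by padding the kernel shape, stride, padding, and dilation tuples with a trailing unit entry, unsqueezing the input to 2D with a shuffle layer, running the 2D convolution, and squeezing back. A minimal PyTorch sketch (not part of the commit) of the same equivalence, useful as a sanity check of the reshape trick:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8)   # (N, C, L) input
w = torch.randn(4, 3, 3)   # (out_channels, in_channels, kernel_size) weight
b = torch.randn(4)

ref = F.conv1d(x, w, b, stride=1, padding=1, dilation=1)

# Same computation as a 2D convolution over a trailing singleton dimension,
# mirroring the shuffle (unsqueeze) -> conv_nd -> shuffle (squeeze) layers
# that the converter builds in TensorRT.
out = F.conv2d(
    x.unsqueeze(-1),        # (N, C, L, 1)
    w.unsqueeze(-1),        # kernel (3,) -> (3, 1)
    b,
    stride=(1, 1),
    padding=(1, 0),
    dilation=(1, 1),
).squeeze(-1)

assert torch.allclose(ref, out, atol=1e-5)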
@@ -442,19 +460,43 @@ def convert_conv_transpose2d3d(ctx):

         bias = trt.Weights(torch_dtype_to_trt(weight.dtype))

-    layer = ctx.network.add_deconvolution_nd(
-        input=input_trt,
+
+    # Handle reshape 1D to 2D
+    if input_dim == 1:
+        kernel_size = kernel_size + (1,)
+        stride = stride + (1,)
+        padding = padding + (0,)
+        dilation = dilation + (1,)
+        unsqueeze_layer = ctx.network.add_shuffle(input_trt)
+        set_layer_precision(ctx, unsqueeze_layer)
+        unsqueeze_layer.reshape_dims = tuple([0] * input.ndim) + (1,)
+        conv_input = unsqueeze_layer.get_output(0)
+    else:
+        conv_input = input_trt
+
+
+    conv_layer = ctx.network.add_deconvolution_nd(
+        input=conv_input,
         num_output_maps=out_channels,
         kernel_shape=kernel_size,
         kernel=kernel,
         bias=bias)
-    layer.stride_nd = stride
-    layer.padding_nd = padding
+    conv_layer.stride_nd = stride
+    conv_layer.padding_nd = padding

     if groups is not None:
-        layer.num_groups = groups
+        conv_layer.num_groups = groups
+
+
+    # Handle reshape 2D back to 1D
+    if input_dim == 1:
+        squeeze_layer = ctx.network.add_shuffle(conv_layer.get_output(0))
+        set_layer_precision(ctx, squeeze_layer)
+        squeeze_layer.reshape_dims = tuple([0] * input.ndim)
+        output._trt = squeeze_layer.get_output(0)
+    else:
+        output._trt = conv_layer.get_output(0)

-    output._trt = layer.get_output(0)


 @tensorrt_converter('torch.div')
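The deconvolution path uses the same unsqueeze/squeeze pattern (note that dilation is extended alongside the other tuples even though dilation_nd is never set on the deconvolution layer, matching the pre-existing code). A corresponding PyTorch sketch for conv_transpose1d, again only an illustration of the reshape equivalence and not part of the commit:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8)   # (N, C, L) input
w = torch.randn(3, 4, 3)   # (in_channels, out_channels, kernel_size) weight
b = torch.randn(4)

ref = F.conv_transpose1d(x, w, b, stride=1, padding=1)

# Transposed convolution over a trailing singleton dimension is equivalent,
# matching the add_deconvolution_nd path added above.
out = F.conv_transpose2d(
    x.unsqueeze(-1),        # (N, C, L, 1)
    w.unsqueeze(-1),        # kernel (3,) -> (3, 1)
    b,
    stride=(1, 1),
    padding=(1, 0),
).squeeze(-1)

assert torch.allclose(ref, out, atol=1e-5)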
