@@ -342,10 +342,6 @@ def convert_gt(ctx):
 
 
 @tensorrt_converter('torch.nn.functional.conv1d')
-def convert_conv1d(ctx):
-    raise NotImplementedError
-
-
 @tensorrt_converter('torch.nn.functional.conv2d')
 @tensorrt_converter('torch.nn.functional.conv3d')
 def convert_conv2d3d(ctx):
@@ -376,31 +372,53 @@ def convert_conv2d3d(ctx):
     if not isinstance(dilation, tuple):
         dilation = (dilation,) * input_dim
 
+
     kernel = weight.detach().cpu().numpy()
 
     if bias is not None:
         bias = bias.detach().cpu().numpy()
 
-    layer = ctx.network.add_convolution_nd(
-        input=input_trt,
+    # Handle reshape 1D to 2D
+    if input_dim == 1:
+        kernel_size = kernel_size + (1,)
+        stride = stride + (1,)
+        padding = padding + (0,)
+        dilation = dilation + (1,)
+        unsqueeze_layer = ctx.network.add_shuffle(input_trt)
+        set_layer_precision(ctx, unsqueeze_layer)
+        unsqueeze_layer.reshape_dims = tuple([0] * input.ndim) + (1,)
+        conv_input = unsqueeze_layer.get_output(0)
+    else:
+        conv_input = input_trt
+
+
+    conv_layer = ctx.network.add_convolution_nd(
+        input=conv_input,
         num_output_maps=out_channels,
         kernel_shape=kernel_size,
         kernel=kernel,
         bias=bias)
-    layer.stride_nd = stride
-    layer.padding_nd = padding
-    layer.dilation_nd = dilation
+    conv_layer.stride_nd = stride
+    conv_layer.padding_nd = padding
+    conv_layer.dilation_nd = dilation
 
     if groups is not None:
-        layer.num_groups = groups
+        conv_layer.num_groups = groups
 
-    output._trt = layer.get_output(0)
+    output._trt = conv_layer.get_output(0)
+
+    # Handle reshape 2D back to 1D
+    if input_dim == 1:
+        squeeze_layer = ctx.network.add_shuffle(conv_layer.get_output(0))
+        set_layer_precision(ctx, squeeze_layer)
+        squeeze_layer.reshape_dims = tuple([0] * input.ndim)
+        output._trt = squeeze_layer.get_output(0)
+    else:
+        output._trt = conv_layer.get_output(0)
 
 
-@tensorrt_converter('torch.nn.functional.conv_transpose1d')
-def convert_conv_transpose1d(ctx):
-    raise NotImplementedError
 
+@tensorrt_converter('torch.nn.functional.conv_transpose1d')
 @tensorrt_converter('torch.nn.functional.conv_transpose2d')
 @tensorrt_converter('torch.nn.functional.conv_transpose3d')
 def convert_conv_transpose2d3d(ctx):
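Both converters now share the same 1D handling: the TRT tensor is unsqueezed to a trailing unit dimension with a shuffle layer (a 0 in reshape_dims tells TensorRT to copy that dimension from the input), kernel_size/stride/padding/dilation are padded out to 2D, and a second shuffle squeezes the result back. The reshape trick itself can be checked at the PyTorch level; the snippet below is a standalone sanity check, not part of the patch, and the shapes and hyperparameters are arbitrary:

import torch
import torch.nn.functional as F

# conv1d on (N, C, L) should match conv2d on (N, C, L, 1) with a (K, 1) kernel,
# which is the reshape trick the converter relies on.
x = torch.randn(2, 4, 16)
w = torch.randn(8, 4, 3)   # (out_channels, in_channels, kernel_size)
b = torch.randn(8)

ref = F.conv1d(x, w, b, stride=2, padding=1)
via_2d = F.conv2d(x.unsqueeze(-1), w.unsqueeze(-1), b,
                  stride=(2, 1), padding=(1, 0)).squeeze(-1)

assert torch.allclose(ref, via_2d, atol=1e-5)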
@@ -442,19 +460,43 @@ def convert_conv_transpose2d3d(ctx):
 
     bias = trt.Weights(torch_dtype_to_trt(weight.dtype))
 
-    layer = ctx.network.add_deconvolution_nd(
-        input=input_trt,
+
+    # Handle reshape 1D to 2D
+    if input_dim == 1:
+        kernel_size = kernel_size + (1,)
+        stride = stride + (1,)
+        padding = padding + (0,)
+        dilation = dilation + (1,)
+        unsqueeze_layer = ctx.network.add_shuffle(input_trt)
+        set_layer_precision(ctx, unsqueeze_layer)
+        unsqueeze_layer.reshape_dims = tuple([0] * input.ndim) + (1,)
+        conv_input = unsqueeze_layer.get_output(0)
+    else:
+        conv_input = input_trt
+
+
+    conv_layer = ctx.network.add_deconvolution_nd(
+        input=conv_input,
         num_output_maps=out_channels,
         kernel_shape=kernel_size,
         kernel=kernel,
         bias=bias)
-    layer.stride_nd = stride
-    layer.padding_nd = padding
+    conv_layer.stride_nd = stride
+    conv_layer.padding_nd = padding
 
     if groups is not None:
-        layer.num_groups = groups
+        conv_layer.num_groups = groups
+
+
+    # Handle reshape 2D back to 1D
+    if input_dim == 1:
+        squeeze_layer = ctx.network.add_shuffle(conv_layer.get_output(0))
+        set_layer_precision(ctx, squeeze_layer)
+        squeeze_layer.reshape_dims = tuple([0] * input.ndim)
+        output._trt = squeeze_layer.get_output(0)
+    else:
+        output._trt = conv_layer.get_output(0)
 
-    output._trt = layer.get_output(0)
 
 
 @tensorrt_converter('torch.div')
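With the NotImplementedError stubs removed, 1D convolutions should convert through the usual torch2trt entry point. A minimal usage sketch, assuming a CUDA device and a torch2trt build that includes these converters; the module and shapes are only illustrative:

import torch
from torch2trt import torch2trt

model = torch.nn.Conv1d(4, 8, kernel_size=3, padding=1).cuda().eval()
x = torch.randn(1, 4, 16).cuda()

model_trt = torch2trt(model, [x])                      # builds the TensorRT engine
print(torch.max(torch.abs(model(x) - model_trt(x))))   # should be close to 0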