Skip to content

Commit 48eb79e

Browse files
authored
Merge pull request #397 from NVIDIA-AI-IOT/const_dim_fix
Const dim fix
2 parents 63895f0 + a08e6dc commit 48eb79e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+262
-67
lines changed

torch2trt/converters/AdaptiveAvgPool2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def convert_AdaptiveAvgPool2d(ctx):
88
input = ctx.method_args[1]
99
output = ctx.method_return
1010

11-
input_trt = trt_(ctx.network, input)
11+
input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
1212

1313
output_size = module.output_size
1414
if not isinstance(output_size, tuple):

torch2trt/converters/BatchNorm1d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
def convert_BatchNorm2d(ctx):
77
module = ctx.method_args[0]
88
input = ctx.method_args[1]
9-
input_trt = trt_(ctx.network, input)
9+
input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
1010
output = ctx.method_return
1111

1212
scale = module.weight.detach().cpu().numpy() / np.sqrt(module.running_var.detach().cpu().numpy() + module.eps)

torch2trt/converters/BatchNorm2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
def convert_BatchNorm2d(ctx):
77
module = ctx.method_args[0]
88
input = ctx.method_args[1]
9-
input_trt = trt_(ctx.network, input)
9+
input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
1010
output = ctx.method_return
1111

1212
scale = module.weight.detach().cpu().numpy() / np.sqrt(

torch2trt/converters/Conv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
def convert_Conv_trt7(ctx):
88
module = ctx.method_args[0]
99
input = ctx.method_args[1]
10-
input_trt = trt_(ctx.network, input)
10+
input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
1111
output = ctx.method_return
1212

1313
input_dim = input.dim() - 2

torch2trt/converters/Conv1d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
def convert_Conv1d(ctx):
77
module = ctx.method_args[0]
88
input = ctx.method_args[1]
9-
input_trt = trt_(ctx.network, input)
9+
input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
1010
output = ctx.method_return
1111

1212
kernel_size = (module.kernel_size[0], 1)

torch2trt/converters/Conv2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
def convert_Conv2d(ctx):
77
module = ctx.method_args[0]
88
input = ctx.method_args[1]
9-
input_trt = trt_(ctx.network, input)
9+
input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
1010
output = ctx.method_return
1111

1212
kernel_size = module.kernel_size

torch2trt/converters/ConvTranspose.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
def convert_ConvTranspose2d_trt7(ctx):
88
module = ctx.method_args[0]
99
input = ctx.method_args[1]
10-
input_trt = trt_(ctx.network, input)
10+
input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
1111
output = ctx.method_return
1212

1313
input_dim = input.dim() - 2

torch2trt/converters/ConvTranspose2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
def convert_ConvTranspose2d(ctx):
66
module = ctx.method_args[0]
77
input = ctx.method_args[1]
8-
input_trt = trt_(ctx.network, input)
8+
input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
99
output = ctx.method_return
1010

1111
kernel_size = module.kernel_size

torch2trt/converters/Identity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@
66
@tensorrt_converter('torch.nn.Dropout3d.forward')
77
def convert_Identity(ctx):
88
input = ctx.method_args[1]
9-
input_trt = trt_(ctx.network, input)
9+
input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
1010
output = ctx.method_return
1111
output._trt = input_trt

torch2trt/converters/Linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
def convert_Linear(ctx):
77
module = ctx.method_args[0]
88
input = ctx.method_args[1]
9-
input_trt = trt_(ctx.network, input)
9+
input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
1010
output = ctx.method_return
1111

1212
# reshape to ...xNx1x1

0 commit comments

Comments
 (0)