Commit fb52a66 (2 parents: 2732b35 + cf229c9)

Merge pull request #691 from jaybdub/einsum_converter_native

Explicit batch and Einsum converter

38 files changed: 154 additions & 79 deletions
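
This merge moves torch2trt's converters from TensorRT's implicit-batch mode to explicit batch: network tensor shapes now carry the batch dimension, which is why the per-file changes below stop subtracting 1 from ranks and axes. A minimal sketch of the two network-creation modes in the TensorRT Python API (illustrative only, not code from this commit):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)

# Legacy implicit batch: network tensors omit the batch dimension.
# network = builder.create_network()

# Explicit batch: shapes include the batch dimension, so PyTorch dims
# map one-to-one onto TensorRT axes.
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flag)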

.gitignore

Lines changed: 3 additions & 1 deletion
@@ -16,4 +16,6 @@ __pycache__/
 *.ipynb_checkpoints
 *.pth
 docs/converters.md
-site
+site
+ToJetsonGrp
+.vscode

torch2trt/converters/BatchNorm1d.py

Lines changed: 4 additions & 3 deletions
@@ -17,20 +17,21 @@ def convert_BatchNorm2d(ctx):
     layer = ctx.network.add_shuffle(input_trt)

     if len(input.shape) == 2:
-        layer.reshape_dims = (input.shape[1], 1, 1)
+        layer.reshape_dims = (input.shape[0], input.shape[1], 1, 1)
     else:
-        layer.reshape_dims = (input.shape[1], input.shape[2], 1)
+        layer.reshape_dims = (input.shape[0], input.shape[1], input.shape[2], 1)

     layer = ctx.network.add_scale(layer.get_output(0), trt.ScaleMode.CHANNEL, bias, scale, power)

     # reshape back to 1D
     layer = ctx.network.add_shuffle(layer.get_output(0))
-    layer.reshape_dims = tuple(output.shape[1:])
+    layer.reshape_dims = tuple(output.shape)

     output._trt = layer.get_output(0)


 @add_module_test(torch.float32, torch.device('cuda'), [(1, 10)])
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 3)])
+@add_module_test(torch.float32, torch.device('cuda'), [(2, 10, 3)], max_batch_size=2)
 def test_BatchNorm1d_basic():
     return torch.nn.BatchNorm1d(10)
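
The converter lowers BatchNorm1d onto TensorRT's channel-wise scale layer by padding the input to 4D, and with explicit batch the leading N must now appear in reshape_dims. A hypothetical PyTorch equivalent of the reshape (not from the commit):

import torch

x = torch.randn(2, 10, 3)        # BatchNorm1d input: (N, C, L)
x4d = x.reshape(2, 10, 3, 1)     # (N, C, L, 1): batch dimension kept explicit
assert x4d.shape == (2, 10, 3, 1)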

torch2trt/converters/Conv1d.py

Lines changed: 6 additions & 2 deletions
@@ -22,7 +22,7 @@ def convert_Conv1d(ctx):

     # reshape to 2D
     layer = ctx.network.add_shuffle(input_trt)
-    layer.reshape_dims = (-1, input.shape[-1], 1)
+    layer.reshape_dims = (input.shape[0], -1, input.shape[-1], 1)

     layer = ctx.network.add_convolution(
         input=layer.get_output(0),
@@ -39,26 +39,30 @@ def convert_Conv1d(ctx):

     # reshape back to 1D
     layer = ctx.network.add_shuffle(layer.get_output(0))
-    layer.reshape_dims = (-1, output.shape[-1])
+    layer.reshape_dims = (input.shape[0], -1, output.shape[-1])

     output._trt = layer.get_output(0)


 @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224)])
+@add_module_test(torch.float32, torch.device('cuda'), [(2, 10, 224)], max_batch_size=2)
 def test_Conv1d_basic():
     return torch.nn.Conv1d(10, 5, kernel_size=1, stride=1, padding=0)


 @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224)])
+@add_module_test(torch.float32, torch.device('cuda'), [(2, 10, 224)], max_batch_size=2)
 def test_Conv1d_stride2():
     return torch.nn.Conv1d(10, 5, kernel_size=1, stride=2, padding=0)


 @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224)])
+@add_module_test(torch.float32, torch.device('cuda'), [(2, 10, 224)], max_batch_size=2)
 def test_Conv1d_kernel3():
     return torch.nn.Conv1d(10, 5, kernel_size=3, stride=2, padding=1)


 @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224)])
+@add_module_test(torch.float32, torch.device('cuda'), [(2, 10, 224)], max_batch_size=2)
 def test_Conv1d_dilation2():
     return torch.nn.Conv1d(10, 5, kernel_size=3, stride=1, padding=1, dilation=2)
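
The Conv1d converter uses the same trick: a 1D convolution is expressed as a 2D convolution over a unit trailing dimension. A self-contained sketch checking that equivalence in PyTorch (illustrative, not commit code):

import torch

conv1d = torch.nn.Conv1d(10, 5, kernel_size=3, stride=1, padding=1)
conv2d = torch.nn.Conv2d(10, 5, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))
with torch.no_grad():
    conv2d.weight.copy_(conv1d.weight.unsqueeze(-1))  # (5, 10, 3) -> (5, 10, 3, 1)
    conv2d.bias.copy_(conv1d.bias)

x = torch.randn(2, 10, 224)
y1d = conv1d(x)
y2d = conv2d(x.unsqueeze(-1)).squeeze(-1)  # (N, C, L) -> (N, C, L, 1) and back
assert torch.allclose(y1d, y2d, atol=1e-5)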

torch2trt/converters/Linear.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def convert_Linear(ctx):

     # reshape back to N
     layer = ctx.network.add_shuffle(layer.get_output(0))
-    layer.reshape_dims = tuple(output.shape[1:])
+    layer.reshape_dims = tuple(output.shape)

     output._trt = layer.get_output(0)

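Same pattern as BatchNorm1d: the final shuffle now restores the full output shape rather than output.shape[1:]. A trivial sketch of the shape the converter must reproduce (assumes standard nn.Linear semantics):

import torch

linear = torch.nn.Linear(10, 5)
x = torch.randn(2, 10)              # (N, in_features)
out = linear(x)
assert tuple(out.shape) == (2, 5)   # reshape_dims == tuple(output.shape)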

torch2trt/converters/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@
 from .clamp import *
 from .compare import *
 from .div import *
+from .einsum import *
 from .expand import *
 from .floordiv import *
 from .gelu import *
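
This import registers the Einsum converter named in the PR title; the converter's source (einsum.py) is not shown in this diff. For illustration only, the PyTorch op it targets:

import torch

q = torch.randn(2, 8, 16)
k = torch.randn(2, 8, 16)
attn = torch.einsum('bqd,bkd->bqk', q, k)  # batched dot products over d
assert attn.shape == (2, 8, 8)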

torch2trt/converters/add.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ def convert_add(ctx):
     input_b = ctx.method_args[1]
     output = ctx.method_return
     input_a_trt, input_b_trt = add_missing_trt_tensors(ctx.network, [input_a, input_b])
-    input_a_trt, input_b_trt = broadcast_trt_tensors(ctx.network, [input_a_trt, input_b_trt], len(output.shape) - 1)
+    input_a_trt, input_b_trt = broadcast_trt_tensors(ctx.network, [input_a_trt, input_b_trt], len(output.shape))
     layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, trt.ElementWiseOperation.SUM)
     output._trt = layer.get_output(0)

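With explicit batch, operands are broadcast to the full output rank (len(output.shape)) rather than the rank minus the batch dimension. The PyTorch semantics being matched (sketch, not commit code):

import torch

a = torch.randn(2, 3, 4)
b = torch.randn(4)        # lower rank; broadcast as if shaped (1, 1, 4)
out = a + b
assert out.shape == (2, 3, 4)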

torch2trt/converters/batch_norm.py

Lines changed: 5 additions & 2 deletions
@@ -18,23 +18,26 @@ def convert_batch_norm_trt7(ctx):
     scale = weight.detach().cpu().numpy() / np.sqrt(running_var.detach().cpu().numpy() + eps)
     bias = bias.detach().cpu().numpy() - running_mean.detach().cpu().numpy() * scale
     power = np.ones_like(scale)
-
-    layer = ctx.network.add_scale_nd(input_trt, trt.ScaleMode.CHANNEL, bias, scale, power, 0)
+
+    layer = ctx.network.add_scale_nd(input_trt, trt.ScaleMode.CHANNEL, bias, scale, power, 1)
     output._trt = layer.get_output(0)



 @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 3, 3)], enabled=trt_version() >= '7.0')
+@add_module_test(torch.float32, torch.device('cuda'), [(2, 10, 3, 3)], enabled=trt_version() >= '7.0', max_batch_size=2)
 def test_batch_norm_2d_trt7():
     return torch.nn.BatchNorm2d(10)


 @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 3, 3, 3)], enabled=trt_version() >= '7.0')
+@add_module_test(torch.float32, torch.device('cuda'), [(2, 10, 3, 3, 3)], enabled=trt_version() >= '7.0', max_batch_size=2)
 def test_batch_norm_3d_2_trt7():
     return torch.nn.BatchNorm3d(10)


 @add_module_test(torch.float32, torch.device('cuda'), [(1, 32, 2, 36, 47)], enabled=trt_version() >= '7.0')
+@add_module_test(torch.float32, torch.device('cuda'), [(2, 32, 2, 36, 47)], enabled=trt_version() >= '7.0', max_batch_size=2)
 def test_batch_norm_3d_trt7():
     return torch.nn.BatchNorm3d(32)

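The channel_axis argument of add_scale_nd moves from 0 to 1 because axis 0 is now the batch dimension. The scale/bias folding itself is unchanged; a sketch verifying it against eval-mode BatchNorm (illustrative):

import torch

bn = torch.nn.BatchNorm2d(10).eval()
x = torch.randn(2, 10, 3, 3)
# y = scale * x + bias, with scale = w / sqrt(var + eps), bias = b - mean * scale
scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
bias = bn.bias - bn.running_mean * scale
y = scale[None, :, None, None] * x + bias[None, :, None, None]
assert torch.allclose(y, bn(x), atol=1e-5)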

torch2trt/converters/cat.py

Lines changed: 5 additions & 2 deletions
@@ -13,10 +13,10 @@ def convert_cat(ctx):

     output = ctx.method_return
     trt_inputs = add_missing_trt_tensors(ctx.network, inputs)
-    trt_inputs = broadcast_trt_tensors(ctx.network, trt_inputs, len(output.shape) - 1)
+    trt_inputs = broadcast_trt_tensors(ctx.network, trt_inputs, len(output.shape))

     layer = ctx.network.add_concatenation(inputs=trt_inputs)
-    layer.axis = dim - 1
+    layer.axis = dim
     output._trt = layer.get_output(0)


@@ -30,15 +30,18 @@ def forward(self, *x):


 @add_module_test(torch.float32, torch.device('cuda'), [(1, 4, 4), (1, 3, 4), (1, 17, 4)])
+@add_module_test(torch.float32, torch.device('cuda'), [(2, 4, 4), (2, 3, 4), (2, 17, 4)], max_batch_size=2)
 def test_Cat_basic():
     return Cat(1)


 @add_module_test(torch.float32, torch.device('cuda'), [(1, 4, 4), (1, 4, 4), (1, 4, 4)])
+@add_module_test(torch.float32, torch.device('cuda'), [(2, 4, 4), (2, 4, 4), (2, 4, 4)], max_batch_size=2)
 def test_Cat_neg1_dim():
     return Cat(-1)


 @add_module_test(torch.float32, torch.device('cuda'), [(1, 4, 4), (1, 4, 4), (1, 4, 4)])
+@add_module_test(torch.float32, torch.device('cuda'), [(2, 4, 4), (2, 4, 4), (2, 4, 4)], max_batch_size=2)
 def test_Cat_neg2_dim():
     return Cat(-2)
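
Concatenation axes now map one-to-one: layer.axis = dim, with no - 1 correction. The PyTorch behavior the converter reproduces (sketch, matching the test_Cat_basic shapes):

import torch

xs = [torch.randn(2, 4, 4), torch.randn(2, 3, 4), torch.randn(2, 17, 4)]
out = torch.cat(xs, dim=1)   # dim 1 == TensorRT axis 1 under explicit batch
assert out.shape == (2, 24, 4)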

torch2trt/converters/chunk.py

Lines changed: 5 additions & 0 deletions
@@ -33,28 +33,33 @@ def forward(self, x):

 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)])
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)])
+@add_module_test(torch.float32, torch.device('cuda'), [(2, 3, 3, 3)], max_batch_size=2)
 def test_torch_chunk_1_1():
     return TorchChunk(1, 1)


 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)])
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)])
+@add_module_test(torch.float32, torch.device('cuda'), [(2, 3, 3, 3)], max_batch_size=2)
 def test_torch_chunk_2_1():
     return TorchChunk(2, 1)


 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)])
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)])
+@add_module_test(torch.float32, torch.device('cuda'), [(2, 3, 3, 3)], max_batch_size=2)
 def test_torch_chunk_3_1():
     return TorchChunk(3, 1)


 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)])
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)])
+@add_module_test(torch.float32, torch.device('cuda'), [(2, 3, 3, 3)], max_batch_size=2)
 def test_torch_chunk_3_2():
     return TorchChunk(3, 2)


 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)])
+@add_module_test(torch.float32, torch.device('cuda'), [(2, 3, 3, 3)], max_batch_size=2)
 def test_tensor_chunk_3_2():
     return TensorChunk(3, 2)
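
Only tests change here; the new cases exercise batch sizes above 1. For reference, the torch.chunk behavior under test (sketch; the TorchChunk(3, 2) arguments are assumed to be chunks=3, dim=2 based on the test names):

import torch

x = torch.randn(2, 3, 3, 3)
parts = torch.chunk(x, 3, dim=2)   # three pieces along dim 2
assert len(parts) == 3 and parts[0].shape == (2, 3, 1, 3)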

torch2trt/converters/div.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@ def convert_div(ctx):
     input_b = ctx.method_args[1]
     output = ctx.method_return
     input_a_trt, input_b_trt = add_missing_trt_tensors(ctx.network, [input_a, input_b])
-    input_a_trt, input_b_trt = broadcast_trt_tensors(ctx.network, [input_a_trt, input_b_trt], len(output.shape) - 1)
+    input_a_trt, input_b_trt = broadcast_trt_tensors(ctx.network, [input_a_trt, input_b_trt], len(output.shape))
     layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, trt.ElementWiseOperation.DIV)
     output._trt = layer.get_output(0)

@@ -24,7 +24,7 @@ def convert_rdiv(ctx):
     input_b = ctx.method_args[0]
     output = ctx.method_return
     input_a_trt, input_b_trt = add_missing_trt_tensors(ctx.network, [input_a, input_b])
-    input_a_trt, input_b_trt = broadcast_trt_tensors(ctx.network, [input_a_trt, input_b_trt], len(output.shape) - 1)
+    input_a_trt, input_b_trt = broadcast_trt_tensors(ctx.network, [input_a_trt, input_b_trt], len(output.shape))
     layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, trt.ElementWiseOperation.DIV)
     output._trt = layer.get_output(0)

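convert_div and convert_rdiv receive the same rank fix as add. The reflected form that convert_rdiv handles, where the tensor is the right-hand operand (sketch, not commit code):

import torch

x = torch.full((2, 3), 2.0)
y = 6.0 / x                  # scalar / tensor dispatches to the rdiv path
assert torch.equal(y, torch.full((2, 3), 3.0))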
