
Commit b5262ae

Shirong Wu authored and Wei Wei committed
Enable split explicit batch dim operator (#77)
Summary:
Pull Request resolved: https://github.com/pytorch/fx2trt/pull/77

ATT

Reviewed By: brad-mengchi, 842974287

Differential Revision: D36458976

fbshipit-source-id: 4bbd4025547244ba03a44af8ea2e96c4b38896a5
1 parent e61fd1a commit b5262ae

File tree

2 files changed: +2 -2 lines changed

fx/converters/acc_ops_converters.py

Lines changed: 1 addition & 1 deletion

@@ -2578,7 +2578,7 @@ def acc_ops_masked_fill_tensor(
     return layer.get_output(0)


-@tensorrt_converter(acc_ops.split, no_explicit_batch_dim=True)
+@tensorrt_converter(acc_ops.split)
 def acc_ops_split(
     network: TRTNetwork,
     target: Target,
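The change above drops the `no_explicit_batch_dim=True` flag from the `split` converter registration, which enables `acc_ops.split` when the network uses an explicit batch dimension. As a hedged illustration of how such a flag could gate converter availability, here is a minimal, self-contained sketch of a converter registry; the names (`CONVERTERS`, `is_supported`, the string op keys) are hypothetical and do not reflect fx2trt's real API.

```python
# Hypothetical sketch of a converter registry that gates ops on
# explicit-batch-dim support. Illustrative only, NOT fx2trt's actual code.

CONVERTERS = {}
NO_EXPLICIT_BATCH_DIM_OPS = set()


def tensorrt_converter(op, no_explicit_batch_dim=False):
    """Decorator factory: registers a converter for `op`."""
    def register(fn):
        CONVERTERS[op] = fn
        # Re-registering an op overwrites its previous mode flag.
        if no_explicit_batch_dim:
            NO_EXPLICIT_BATCH_DIM_OPS.add(op)
        else:
            NO_EXPLICIT_BATCH_DIM_OPS.discard(op)
        return fn
    return register


def is_supported(op, explicit_batch_dim):
    """An op is usable if it has a converter and is not restricted
    to implicit-batch mode when explicit batch dim is requested."""
    if op not in CONVERTERS:
        return False
    if explicit_batch_dim and op in NO_EXPLICIT_BATCH_DIM_OPS:
        return False
    return True


# Before this commit: split was registered as implicit-batch only.
@tensorrt_converter("split", no_explicit_batch_dim=True)
def convert_split_old(*args):
    pass

print(is_supported("split", explicit_batch_dim=True))   # False

# After this commit: the flag is dropped, so split works in explicit-batch mode.
@tensorrt_converter("split")
def convert_split_new(*args):
    pass

print(is_supported("split", explicit_batch_dim=True))   # True
```

Under this sketch, removing the flag is all that is needed to make the operator visible to explicit-batch lowering, which matches the one-line nature of the diff.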

test/trt_lower/trt_operator_supported_test.py

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@ def test_unsupport_node_explicit_batch_dim(self):
         class TestModule(nn.Module):
             def forward(self, x):
                 y = torch.add(input=x, other=x)
-                return torch.split(y, 2)
+                return torch.max_pool1d(y, 1)

         mod = TestModule()
         traced_mod = acc_tracer.trace(mod, [torch.randn(5, 2)])
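The test above checks that an unsupported node is reported in explicit-batch mode; since this commit makes `split` supported, the test swaps in `torch.max_pool1d` as its unsupported op. The idea can be sketched with a plain-Python stand-in; the registry set and `unsupported_nodes` helper below are hypothetical, not fx2trt's real support-checking API.

```python
# Hypothetical sketch of scanning a traced graph's ops for ones that lack
# an explicit-batch TensorRT converter. Illustrative only, NOT fx2trt code.

# After this commit, split has an explicit-batch converter alongside add.
SUPPORTED_IN_EXPLICIT_BATCH = {"add", "split"}


def unsupported_nodes(graph_ops):
    """Return the ops in a traced graph that have no converter."""
    return [op for op in graph_ops if op not in SUPPORTED_IN_EXPLICIT_BATCH]


# The old test module used split as its unsupported op; once split gained
# explicit-batch support it no longer triggers the check, so the test
# switched to max_pool1d.
print(unsupported_nodes(["add", "split"]))        # []
print(unsupported_nodes(["add", "max_pool1d"]))   # ['max_pool1d']
```

This explains why the test change accompanies the converter change: without swapping the op, the "unsupported node" test would stop exercising its failure path.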
