Skip to content

Commit b0cc8e7

Browse files
authored
Assert that inputs are contiguous (#418)
* Assert that inputs are contiguous * Turn non-contiguous tensors into contiguous * Add unit test * Fix tabs
1 parent ea835dc commit b0cc8e7

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

torch2trt/tests/test_contiguous.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import torch
2+
from torch2trt import torch2trt
3+
4+
5+
def test_contiguous():
6+
net = torch.nn.Conv2d(3, 10, kernel_size=3)
7+
net.eval().cuda()
8+
9+
test_tensor = torch.randn((1, 25, 25, 3)).cuda().permute((0, 3, 1, 2))
10+
11+
with torch.no_grad():
12+
test_out = net(test_tensor)
13+
14+
with torch.no_grad():
15+
trt_net = torch2trt(net, [test_tensor])
16+
test_trt_out = trt_net(test_tensor)
17+
18+
delta = (test_out.contiguous() - test_trt_out.contiguous()).abs().sum()
19+
assert delta < 1e-3, f"Delta: {delta}"
20+

torch2trt/torch2trt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ def forward(self, *inputs):
427427

428428
for i, input_name in enumerate(self.input_names):
429429
idx = self.engine.get_binding_index(input_name)
430-
bindings[idx] = inputs[i].data_ptr()
430+
bindings[idx] = inputs[i].contiguous().data_ptr()
431431

432432
self.context.execute_async(
433433
batch_size, bindings, torch.cuda.current_stream().cuda_stream

0 commit comments

Comments
 (0)