Skip to content

Commit dedb240

Browse files
committed
many trt 10 fixes
1 parent 8aec25b commit dedb240

File tree

6 files changed

+498
-371
lines changed

6 files changed

+498
-371
lines changed

tests/converter_tests/test_converters.py

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def forward(self, x, y):
160160
return x
161161

162162
module = IAdd().cuda().eval()
163-
inputs = [torch.randn(1, 3, 4).cuda(), torch.randn(1, 3, 4).cuda()]
163+
inputs = [torch.ones(1, 3, 4).cuda(), torch.ones(1, 3, 4).cuda()]
164164
cross_validate(module, inputs, fp16_mode=False, tol=1e-2)
165165

166166

@@ -179,29 +179,6 @@ def test_radd_float():
179179
# TODO: radd, add, iadd
180180

181181

182-
@pytest.mark.parametrize("kernel_size,stride,padding,ceil_mode,count_include_pad", [
183-
(3, 2, 1, False, True),
184-
(3, 2, 1, True, False)
185-
])
186-
def test_avg_pool2d(kernel_size, stride, padding, ceil_mode, count_include_pad):
187-
module = UnaryModule(lambda x: torch.nn.functional.avg_pool2d(
188-
x, kernel_size, stride, padding, ceil_mode, count_include_pad
189-
)).cuda().eval()
190-
inputs = [torch.randn(1, 3, 8, 8).cuda()]
191-
cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
192-
193-
194-
@pytest.mark.parametrize("kernel_size,stride,padding,ceil_mode,count_include_pad", [
195-
(3, 2, 1, False, True),
196-
(3, 2, 1, True, False)
197-
])
198-
def test_avg_pool3d(kernel_size, stride, padding, ceil_mode, count_include_pad):
199-
module = UnaryModule(lambda x: torch.nn.functional.avg_pool3d(
200-
x, kernel_size, stride, padding, ceil_mode, count_include_pad
201-
)).cuda().eval()
202-
inputs = [torch.randn(1, 3, 8, 8, 8).cuda()]
203-
cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
204-
205182

206183
def test_batch_norm_1d():
207184
module = nn.BatchNorm1d(3).cuda().eval()
@@ -413,7 +390,7 @@ def fn(x):
413390
x /= val
414391
return x
415392
module = UnaryModule(fn).cuda().eval()
416-
inputs = [torch.randn(1, 4, 4).cuda()]
393+
inputs = [torch.ones(1, 4, 4).cuda()]
417394
cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
418395

419396

@@ -455,10 +432,11 @@ def test_flatten(start_dim, end_dim):
455432
cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
456433

457434

458-
def test_floordiv():
435+
@pytest.mark.parametrize("denom", [1., 2.])
436+
def test_floordiv(denom):
459437
module = BinaryModule(lambda x, y: x // y).cuda().eval()
460-
inputs = [torch.randn(1, 2, 3, 4, 5).cuda()]
461-
inputs.append(torch.ones_like(inputs[0])*2)
438+
inputs = [torch.ones(1, 2, 3, 4, 5).cuda()]
439+
inputs.append(torch.ones_like(inputs[0]) * denom)
462440
cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
463441

464442

@@ -578,15 +556,13 @@ def test_matmul(shape_a, shape_b):
578556
cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
579557

580558

581-
@pytest.mark.parametrize("nd", [1,2,3])
582559
@pytest.mark.parametrize(
583560
"kernel_size,stride,padding,dilation,ceil_mode", [
584561
(3, 2, 1, 1, False),
585-
(3, 2, 1, 1, False),
586-
(3, 2, 1, 1, False),
587562
]
588563
)
589-
def test_max_pool(nd, kernel_size, stride, padding, dilation, ceil_mode):
564+
@pytest.mark.parametrize("nd", [1,2,3])
565+
def test_max_pool_nd(nd, kernel_size, stride, padding, dilation, ceil_mode):
590566
if nd == 1:
591567
cls = nn.MaxPool1d
592568
elif nd == 2:
@@ -598,6 +574,26 @@ def test_max_pool(nd, kernel_size, stride, padding, dilation, ceil_mode):
598574
inputs = [torch.randn(*input_size).cuda()]
599575
cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
600576

577+
578+
@pytest.mark.parametrize(
579+
"kernel_size,stride,padding,ceil_mode,count_include_pad", [
580+
(3, 2, 1, False, False),
581+
]
582+
)
583+
@pytest.mark.parametrize("nd", [1,2,3])
584+
def test_avg_pool_nd(nd, kernel_size, stride, padding, ceil_mode, count_include_pad):
585+
if nd == 1:
586+
cls = nn.AvgPool1d
587+
elif nd == 2:
588+
cls = nn.AvgPool2d
589+
elif nd == 3:
590+
cls = nn.AvgPool3d
591+
module = cls(kernel_size,stride,padding,ceil_mode=ceil_mode, count_include_pad=count_include_pad).cuda().eval()
592+
input_size = [1, 3] + [4]*nd
593+
inputs = [torch.randn(*input_size).cuda()]
594+
cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
595+
596+
601597
@pytest.mark.parametrize("op", ["min","max", "fmod"])
602598
def test_binary_op_elementwise(op):
603599
if op == "max":

0 commit comments

Comments
 (0)