Commit c717efa

more unit test cases
1 parent 39cd99c

2 files changed: 112 additions & 7 deletions


tests/converter_tests/test_converters.py

Lines changed: 102 additions & 0 deletions
@@ -598,3 +598,105 @@ def test_max_pool(nd, kernel_size, stride, padding, dilation, ceil_mode):
     inputs = [torch.randn(*input_size).cuda()]
     cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
 
+@pytest.mark.parametrize("op", ["min", "max", "fmod"])
+def test_binary_op_elementwise(op):
+    if op == "max":
+        fn = lambda x, y: torch.max(x, y)
+    elif op == "min":
+        fn = lambda x, y: torch.min(x, y)
+    elif op == "fmod":
+        fn = lambda x, y: torch.fmod(x, y)
+
+    module = BinaryModule(fn).cuda().eval()
+    inputs = [torch.randn(1, 3, 3).cuda(), torch.randn(1, 3, 3).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+@pytest.mark.parametrize("op", ["min", "max", "mean", "sum"])
+def test_reduce_op(op):
+    if op == "max":
+        fn = lambda x: torch.max(x)
+    elif op == "min":
+        fn = lambda x: torch.min(x)
+    elif op == "mean":
+        fn = lambda x: torch.mean(x)
+    elif op == "sum":
+        fn = lambda x: torch.sum(x)
+
+    module = UnaryModule(fn).cuda().eval()
+    inputs = [torch.randn(1, 3, 3).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+@pytest.mark.parametrize("op", ["mul", "__mul__", "__rmul__"])
+@pytest.mark.parametrize("scalar", [2, 2.])
+def test_mul_scalar(op, scalar):
+    op_map = {
+        "mul": lambda x: torch.mul(x, scalar),
+        "__mul__": lambda x: x * scalar,
+        "__rmul__": lambda x: scalar * x
+    }
+
+    module = UnaryModule(op_map[op]).cuda().eval()
+    inputs = [torch.randn(1, 3, 3).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+@pytest.mark.parametrize("dim,start,length", [
+    (0, 0, 2),
+    (1, 1, 2),
+    (-1, -1, 1)
+])
+def test_narrow(dim, start, length):
+    module = UnaryModule(lambda x: torch.narrow(x, dim, start, length)).cuda().eval()
+    inputs = [torch.randn(3, 3).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+
+def test_ne_binary():
+    module = BinaryModule(lambda x, y: x != y).cuda().eval()
+    inputs = [torch.zeros(1, 3, 3).cuda()]
+    inputs.append(inputs[0].clone())
+    inputs[0][0] = 1
+
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+
+@pytest.mark.parametrize("p", [1, 2])
+@pytest.mark.parametrize("dim", [1, 2])
+def test_normalize(p, dim):
+    module = UnaryModule(lambda x: torch.nn.functional.normalize(x, p, dim)).cuda().eval()
+    inputs = [torch.zeros(1, 3, 3).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+
+@pytest.mark.parametrize("pad,mode,value", [
+    ((1, 1), "constant", 0.),
+    ((1, 1, 2, 2), "constant", 0.),
+    ((0, 1, 2, 1, 3, 3), "constant", 0.),
+])
+def test_pad(pad, mode, value):
+    module = UnaryModule(
+        lambda x: torch.nn.functional.pad(
+            x, pad, mode, value
+        )
+    ).cuda().eval()
+    inputs = [torch.randn(3, 3, 4, 2).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+
+@pytest.mark.parametrize("permutation", [
+    (0, 2, 1),
+    (0, 2, 1, 3)
+])
+def test_permute(permutation):
+    module = UnaryModule(
+        lambda x: x.permute(*permutation)
+    ).cuda().eval()
+    sizes = [i + 1 for i in range(len(permutation))]
+    inputs = [torch.randn(*sizes).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
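Note: the new tests lean on the suite's existing UnaryModule and BinaryModule wrappers plus the cross_validate helper, which presumably builds the TensorRT module via torch2trt and checks its output against PyTorch within tol. Those helpers are defined elsewhere in the test code; the sketch below only illustrates the shape they are assumed to have, not the repository's actual implementation.

import torch

# Hypothetical stand-ins for the test suite's wrapper modules.
class UnaryModule(torch.nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn  # single-input callable, e.g. lambda x: torch.sum(x)

    def forward(self, x):
        return self.fn(x)

class BinaryModule(torch.nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn  # two-input callable, e.g. lambda x, y: torch.max(x, y)

    def forward(self, x, y):
        return self.fn(x, y)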

torch2trt/converters.py

Lines changed: 10 additions & 7 deletions
@@ -1283,12 +1283,11 @@ def convert_mean(ctx):
     input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
     output = ctx.method_return
 
-    # get dims from args or kwargs
-    if 'dim' in ctx.method_kwargs:
-        dim = ctx.method_kwargs['dim']
-    elif len(ctx.method_args) >= 2:
-        dim = ctx.method_args[1]
-
+    dim = get_arg(ctx, "dim", 1, None)
+
+    if dim is None:
+        dim = [i for i in range(input.ndim)]
+
     # convert list to tuple
     if isinstance(dim, list):
         dim = tuple(dim)
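get_arg is not shown in this diff; going by how it is called here (and in convert_narrow below), it resolves an argument by keyword name first, then by position, then falls back to a default. A rough, assumed sketch of that lookup, not the actual torch2trt source:

def get_arg(ctx, name, pos, default):
    # Assumed behavior: keyword wins, then positional, then the default.
    if name in ctx.method_kwargs:
        return ctx.method_kwargs[name]
    if len(ctx.method_args) > pos:
        return ctx.method_args[pos]
    return default

With that, a call like torch.mean(x) with no dim leaves dim as None, and the new fallback reduces over every dimension of the input.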
@@ -1414,6 +1413,7 @@ def convert_narrow(ctx):
     start = [0]*len(shape)
     stride = [1]*len(shape)
     dim = ctx.method_args[1] if get_arg(ctx, 'dim', pos=1, default=0) >=0 else len(shape)+get_arg(ctx, 'dim', pos=1, default=0)
+
     start[dim] = ctx.method_args[2]
     shape[dim] = ctx.method_args[3]
     # not consider batch dimension
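The only change in this hunk is the added blank line; for context, the dim line above it remaps a negative dim to len(shape) + dim, which is the path the new (-1, -1, 1) case in test_narrow exercises. A quick illustrative check of that remapping in plain PyTorch (no TensorRT involved):

import torch

x = torch.randn(3, 3)
dim = -1
dim = dim if dim >= 0 else len(x.shape) + dim   # -1 -> 1, mirroring convert_narrow
assert torch.equal(torch.narrow(x, -1, 0, 2), torch.narrow(x, dim, 0, 2))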
@@ -1469,7 +1469,10 @@ def convert_pad(ctx):
     input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
     output = ctx.method_return
 
-    pad = ctx.method_args[1]
+    pad = get_arg(ctx, "pad", 1, None)
+    mode = get_arg(ctx, "mode", 2, "constant")
+    value = get_arg(ctx, "value", 3, 0.)
+
     pre_padding = (pad[2], pad[0])
     post_padding = (pad[3], pad[1])
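With get_arg, pad, mode, and value are now picked up whether test_pad passes them positionally or by keyword. For the 4-D inputs used in the tests, torch.nn.functional.pad's tuple is ordered (left, right, top, bottom), which is what the (pad[2], pad[0]) and (pad[3], pad[1]) lines unpack into the converter's (top, left) pre-padding and (bottom, right) post-padding. A small illustration of that ordering in plain PyTorch (no TensorRT involved):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 3, 4)        # N, C, H, W
y = F.pad(x, (1, 2, 3, 4))         # pad = (left, right, top, bottom)
assert y.shape == (1, 1, 3 + 3 + 4, 4 + 1 + 2)   # H grows by top+bottom, W by left+right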
