Commit 15fe819

update tests
1 parent c717efa commit 15fe819

2 files changed: +121 additions, -16 deletions

tests/converter_tests/test_converters.py

Lines changed: 114 additions & 13 deletions
@@ -612,21 +612,90 @@ def test_binary_op_elementwise(op):
     inputs = [torch.randn(1, 3, 3).cuda(), torch.randn(1, 3, 3).cuda()]
     cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
 
-@pytest.mark.parametrize("op", ["min","max","mean","sum"])
-def test_reduce_op(op):
-    if op == "max":
-        fn = lambda x: torch.max(x)
-    elif op == "min":
-        fn = lambda x: torch.min(x)
-    elif op == "mean":
-        fn = lambda x: torch.mean(x)
-    elif op == "sum":
-        fn = lambda x: torch.sum(x)
-
-    module = UnaryModule(fn).cuda().eval()
+unary_1d_randn_ops = {
+    "torch.max": lambda x: torch.max(x),
+    "torch.Tensor.max": lambda x: x.max(),
+    "torch.min": lambda x: torch.min(x),
+    "torch.Tensor.min": lambda x: x.min(),
+    "torch.mean": lambda x: torch.mean(x),
+    "torch.Tensor.mean": lambda x: x.mean(),
+    "torch.sum": lambda x: torch.sum(x),
+    "torch.Tensor.sum": lambda x: x.sum(),
+    "torch.prod": lambda x: torch.prod(x),
+    "torch.Tensor.prod": lambda x: x.prod(),
+    "torch.relu": lambda x: torch.relu(x),
+    "torch.nn.functional.relu": lambda x: torch.nn.functional.relu(x),
+    "torch.Tensor.relu": lambda x: x.relu(),
+    "torch.nn.functional.relu6": lambda x: torch.nn.functional.relu6(x),
+    "torch.sigmoid": lambda x: torch.sigmoid(x),
+    "torch.nn.functional.sigmoid": lambda x: torch.nn.functional.sigmoid(x),
+    "torch.Tensor.sigmoid": lambda x: x.sigmoid(),
+    "torch.nn.functional.silu": lambda x: torch.nn.functional.silu(x),
+    "torch.Tensor.softmax": lambda x: x.softmax(1),
+    "torch.nn.functional.softmax": lambda x: torch.nn.functional.softmax(x, 1),
+    "torch.Tensor.squeeze": lambda x: x.squeeze(),
+    "torch.squeeze": lambda x: torch.squeeze(x),
+    "torch.stack": lambda x: torch.stack([x, x], dim=1),
+    "torch.sub": lambda x: torch.sub(x, x),
+    "torch.Tensor.__sub__": lambda x: x - x,
+    "torch.Tensor.__rsub__[int]": lambda x: 1 - x,
+    "torch.Tensor.__rsub__[float]": lambda x: 1.0 - x,
+    "torch.tanh": lambda x: torch.tanh(x),
+    "torch.nn.functional.tanh": lambda x: torch.nn.functional.tanh(x),
+    "torch.tensor": lambda x: torch.tensor(x),
+    "torch.Tensor.transpose": lambda x: x.transpose(1, 2),
+    "torch.transpose": lambda x: torch.transpose(x, 1, 2),
+    "torch.exp": lambda x: torch.exp(x),
+    "torch.Tensor.exp": lambda x: x.exp(),
+    "torch.abs": lambda x: torch.abs(x),
+    "torch.Tensor.abs": lambda x: x.abs(),
+    "torch.neg": lambda x: torch.neg(x),
+    "torch.Tensor.neg": lambda x: -x,
+    "torch.sin": lambda x: torch.sin(x),
+    "torch.Tensor.sin": lambda x: x.sin(),
+    "torch.cos": lambda x: torch.cos(x),
+    "torch.Tensor.cos": lambda x: x.cos(),
+    "torch.sinh": lambda x: torch.sinh(x),
+    "torch.Tensor.sinh": lambda x: x.sinh(),
+    "torch.cosh": lambda x: torch.cosh(x),
+    "torch.Tensor.cosh": lambda x: x.cosh(),
+    "torch.atan": lambda x: torch.atan(x),
+    "torch.Tensor.atan": lambda x: x.atan(),
+    "torch.ceil": lambda x: torch.ceil(x),
+    "torch.Tensor.ceil": lambda x: x.ceil(),
+    "torch.floor": lambda x: torch.floor(x),
+    "torch.Tensor.floor": lambda x: x.floor()
+}
+
+@pytest.mark.parametrize("op", unary_1d_randn_ops.keys())
+def test_unary_1d_randn(op):
+    module = UnaryModule(unary_1d_randn_ops[op]).cuda().eval()
     inputs = [torch.randn(1, 3, 3).cuda()]
     cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
 
+
+unary_1d_positive_ops = {
+    "torch.log": lambda x: torch.log(x),
+    "torch.Tensor.log": lambda x: x.log(),
+    "torch.sqrt": lambda x: torch.sqrt(x),
+    "torch.Tensor.sqrt": lambda x: x.sqrt(),
+    "torch.reciprocal": lambda x: torch.reciprocal(x),
+    "torch.Tensor.reciprocal": lambda x: x.reciprocal(),
+    "torch.tan": lambda x: torch.tan(x),
+    "torch.Tensor.tan": lambda x: x.tan(),
+    "torch.asin": lambda x: torch.asin(x),
+    "torch.Tensor.asin": lambda x: x.asin(),
+    "torch.acos": lambda x: torch.acos(x),
+    "torch.Tensor.acos": lambda x: x.acos(),
+}
+
+@pytest.mark.parametrize("op", unary_1d_positive_ops.keys())
+def test_unary_1d_ones(op):
+    module = UnaryModule(unary_1d_positive_ops[op]).cuda().eval()
+    inputs = [0.5 * torch.ones(1, 3, 3).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+
 @pytest.mark.parametrize("op", ["mul", "__mul__", "__rmul__"])
 @pytest.mark.parametrize("scalar", [2, 2.])
 def test_mul_scalar(op, scalar):
@@ -699,4 +768,36 @@ def test_permute(permutation):
     sizes = [i + 1 for i in range(len(permutation))]
 
     inputs = [torch.randn(*sizes).cuda()]
-    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+@pytest.mark.parametrize("op", [
+    "torch.pow",
+    "torch.Tensor.__ipow__",
+    "torch.Tensor.__pow__",
+    "torch.Tensor.__rpow__"
+])
+@pytest.mark.parametrize("scalar", [2, 2.])
+def test_scalar_op(op, scalar):
+    if op == "torch.pow":
+        fn = lambda x: torch.pow(x, scalar)
+    elif op == "torch.Tensor.__ipow__":
+        def ipow(x):
+            x **= scalar
+            return x
+        fn = ipow
+    elif op == "torch.Tensor.__pow__":
+        fn = lambda x: x ** scalar
+    elif op == "torch.Tensor.__rpow__":
+        fn = lambda x: scalar ** x
+
+
+    module = UnaryModule(fn).cuda().eval()
+    inputs = [torch.randn(1, 2).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+
+def test_prelu():
+    module = nn.PReLU(4).cuda().eval()
+    inputs = [torch.randn(1, 4, 3, 3).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+

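Note: the tests rely on the UnaryModule wrapper and the cross_validate helper defined elsewhere in the torch2trt test suite; neither appears in this diff. A minimal sketch of what they might look like, inferred only from how they are called above (hypothetical, not the repository's actual implementation):

    import torch
    import torch.nn as nn
    from torch2trt import torch2trt


    class UnaryModule(nn.Module):
        """Wrap a one-argument callable so it can be converted like a regular nn.Module."""

        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, x):
            return self.fn(x)


    def cross_validate(module, inputs, fp16_mode=False, tol=1e-1):
        """Run the module in PyTorch and through torch2trt, then compare the outputs."""
        module_trt = torch2trt(module, inputs, fp16_mode=fp16_mode)
        out = module(*inputs)
        out_trt = module_trt(*inputs)
        assert torch.allclose(out, out_trt, atol=tol, rtol=tol)

Because the new tests parametrize over the dictionary keys, each op gets a readable test id (for example test_unary_1d_randn[torch.Tensor.relu]), so a single converter can be exercised with pytest -k.
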
torch2trt/converters.py

Lines changed: 7 additions & 3 deletions
@@ -1701,9 +1701,13 @@ def convert_squeeze(ctx):
     output = ctx.method_return
     dim = get_arg(ctx, 'dim', pos=1, default=None)
 
-    if dim < 0:
-        dim = len(input.shape) + dim
-    assert dim >= 0
+    if dim is None:
+        dim = tuple([i for i in range(input.ndim)])
+
+    dim = torch_dim_resolve_negative(dim, input.ndim)
+    # if dim < 0:
+    #     dim = len(input.shape) + dim
+    # assert dim >= 0
 
     input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
 

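The rewritten convert_squeeze delegates negative-dimension handling to torch_dim_resolve_negative, a helper that is not part of this diff. Judging from the call site, which passes either an int or a tuple of dims together with input.ndim, it presumably maps negative indices to their non-negative equivalents; a rough sketch under that assumption (hypothetical, not the library's actual code):

    def torch_dim_resolve_negative(dim, ndim):
        # Normalize possibly-negative dimension indices into the range [0, ndim).
        if isinstance(dim, int):
            dim = (dim,)
        return tuple(d % ndim for d in dim)

With behaviour like that, squeeze() with dim=None resolves to every axis (via the new dim is None branch) and squeeze(-1) on a 3-d input resolves to axis 2, which is what the converter needs before building the TensorRT layers.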