|
| 1 | +from torch2trt.torch2trt import * |
| 2 | +from torch2trt.module_test import add_module_test |
| 3 | + |
| 4 | + |
| 5 | +@tensorrt_converter("torch.nn.functional.max_pool3d") |
| 6 | +@tensorrt_converter("torch.max_pool3d") |
| 7 | +def convert_max_pool3d(ctx): |
| 8 | + # parse args |
| 9 | + input = get_arg(ctx, "input", pos=0, default=None) |
| 10 | + kernel_size = get_arg(ctx, "kernel_size", pos=1, default=None) |
| 11 | + stride = get_arg(ctx, "stride", pos=2, default=None) |
| 12 | + padding = get_arg(ctx, "padding", pos=3, default=0) |
| 13 | + dilation = get_arg(ctx, "dilation", pos=4, default=1) |
| 14 | + ceil_mode = get_arg(ctx, "ceil_mode", pos=5, default=False) |
| 15 | + |
| 16 | + # get input trt tensor (or create constant if it doesn't exist) |
| 17 | + input_trt = add_missing_trt_tensors(ctx.network, [input])[0] |
| 18 | + |
| 19 | + output = ctx.method_return |
| 20 | + |
| 21 | + # get kernel size |
| 22 | + if not isinstance(kernel_size, tuple): |
| 23 | + kernel_size = (kernel_size,) * 3 |
| 24 | + |
| 25 | + # get stride |
| 26 | + if not isinstance(stride, tuple): |
| 27 | + stride = (stride,) * 3 |
| 28 | + |
| 29 | + # get padding |
| 30 | + if not isinstance(padding, tuple): |
| 31 | + padding = (padding,) * 3 |
| 32 | + |
| 33 | + layer = ctx.network.add_pooling_nd( |
| 34 | + input=input_trt, type=trt.PoolingType.MAX, window_size=kernel_size |
| 35 | + ) |
| 36 | + |
| 37 | + layer.stride_nd = stride |
| 38 | + layer.padding_nd = padding |
| 39 | + |
| 40 | + if ceil_mode: |
| 41 | + layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP |
| 42 | + |
| 43 | + output._trt = layer.get_output(0) |
| 44 | + |
| 45 | + |
| 46 | +@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 4, 6, 7)]) |
| 47 | +@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 5, 7, 8)]) |
| 48 | +def test_MaxPool3d_without_ceil_mode(): |
| 49 | + return torch.nn.MaxPool3d(kernel_size=3, stride=2, padding=1, ceil_mode=False) |
| 50 | + |
| 51 | + |
| 52 | +@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 4, 6, 7)]) |
| 53 | +@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 5, 7, 8)]) |
| 54 | +def test_MaxPool3d_with_ceil_mode(): |
| 55 | + return torch.nn.MaxPool3d(kernel_size=3, stride=2, padding=1, ceil_mode=True) |
0 commit comments