Skip to content

Commit cbefcbb

Browse files
authored
feat: support native_dropout dynamo converter (#2931)
1 parent f411b6f commit cbefcbb

File tree

3 files changed

+106
-0
lines changed

3 files changed

+106
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3243,3 +3243,36 @@ def aten_ops_index_select(
32433243
args[1],
32443244
args[2],
32453245
)
3246+
3247+
3248+
def dropout_inference_validator(node: Node) -> bool:
3249+
train_mode = args_bounds_check(node.args, 2, None)
3250+
if train_mode is False:
3251+
return True
3252+
else: # train_mode is True or None
3253+
_LOGGER.debug(
3254+
"Currently only inference mode is supported for dropout operation."
3255+
)
3256+
return False
3257+
3258+
3259+
@dynamo_tensorrt_converter(
3260+
torch.ops.aten.native_dropout.default,
3261+
capability_validator=dropout_inference_validator,
3262+
)
3263+
def aten_ops_native_dropout(
3264+
ctx: ConversionContext,
3265+
target: Target,
3266+
args: Tuple[Argument, ...],
3267+
kwargs: Dict[str, Argument],
3268+
name: str,
3269+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
3270+
return impl.unary.native_dropout(
3271+
ctx,
3272+
target,
3273+
SourceIR.ATEN,
3274+
name,
3275+
args[0],
3276+
args[1],
3277+
args_bounds_check(args, 2, None),
3278+
)

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,3 +571,20 @@ def isnan(
571571
)
572572

573573
return nan_values_mask
574+
575+
576+
def native_dropout(
577+
ctx: ConversionContext,
578+
target: Target,
579+
source_ir: Optional[SourceIR],
580+
name: str,
581+
input_val: Union[TRTTensor, torch.Tensor, np.ndarray],
582+
p: float,
583+
train: Optional[bool] = False,
584+
) -> TRTTensor:
585+
if train is False:
586+
identity_layer = ctx.net.add_identity(input_val)
587+
set_layer_name(identity_layer, target, f"{name}_input", source_ir)
588+
mask = np.ones(input_val.shape, dtype=bool)
589+
mask = get_trt_tensor(ctx, mask, f"{name}_mask")
590+
return identity_layer.get_output(0), mask
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestDropOutConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
((10,), 0, False),
13+
((1, 3), 0.3, False),
14+
((2, 2, 2), 0.5),
15+
((2, 2, 2, 2), 1),
16+
]
17+
)
18+
def test_native_dropout(self, input_shape, p, train=False):
19+
class NativeDropout(nn.Module):
20+
def forward(self, input):
21+
return torch.ops.aten.native_dropout.default(input, p, train)
22+
23+
inputs = [torch.randn(input_shape)]
24+
self.run_test(
25+
NativeDropout(),
26+
inputs,
27+
)
28+
29+
@parameterized.expand(
30+
[
31+
(
32+
torch.randn(
33+
10,
34+
),
35+
0,
36+
False,
37+
),
38+
(torch.randn(1, 3), 0.3, False),
39+
(torch.randn(2, 2, 2), 0.5),
40+
(torch.randn(2, 2, 2, 2), 1),
41+
]
42+
)
43+
def test_native_dropout_pytorch(self, input, p, train=False):
44+
class NativeDropout(nn.Module):
45+
def forward(self):
46+
return torch.ops.aten.native_dropout.default(input, p, train)
47+
48+
inputs = []
49+
self.run_test(
50+
NativeDropout(),
51+
inputs,
52+
)
53+
54+
55+
if __name__ == "__main__":
56+
run_tests()

0 commit comments

Comments
 (0)