diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 4482a00f79..6e2dcb6ffc 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -881,8 +881,30 @@ def aten_ops_select( ) +def index_put_indices_continuity_validator( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: + idxs = node.args[1] # this is a list of indices + present_indices = [idx is not None for idx in idxs] + if len(present_indices) == 0: + return True + first_present_index = next((i for i, v in enumerate(present_indices) if v), None) + if first_present_index is None: + return False + rev_index = next((i for i, v in enumerate(reversed(present_indices)) if v), None) + if rev_index is None: + return False + last_present_index = len(present_indices) - 1 - rev_index + for i in range(first_present_index, last_present_index + 1): + if not present_indices[i]: + return False + return True + + @dynamo_tensorrt_converter( torch.ops.aten.index_put.default, + capability_validator=index_put_indices_continuity_validator, + supports_dynamic_shapes=True, ) @enforce_tensor_types( { diff --git a/tests/py/dynamo/conversion/test_index_put_aten.py b/tests/py/dynamo/conversion/test_index_put_aten.py index 74e38cd0c5..db2c5c09d6 100644 --- a/tests/py/dynamo/conversion/test_index_put_aten.py +++ b/tests/py/dynamo/conversion/test_index_put_aten.py @@ -1,6 +1,7 @@ import torch from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input from .harness import DispatchTestCase @@ -245,5 +246,193 @@ def forward(self, source_tensor, value_tensor): ) +class TestIndexIndexPutDynamicConverter(DispatchTestCase): + @parameterized.expand( + [ + param( + test_name="1d_indices_single", + indices_tensor=(torch.tensor([0], dtype=torch.int32),), + value_tensor=torch.tensor([1], dtype=torch.float32), + input_min_shape=(1,), + input_opt_shape=(5,), + input_max_shape=(5,), + ), + param( + test_name="1d_indices_multiple", + indices_tensor=(torch.tensor([0, 3], dtype=torch.int32),), + value_tensor=torch.tensor([1, 3], dtype=torch.float32), + input_min_shape=(1,), + input_opt_shape=(5,), + input_max_shape=(5,), + ), + param( + test_name="2d_indices_single", + indices_tensor=( + torch.tensor([2], dtype=torch.int32), + torch.tensor([0], dtype=torch.int32), + ), + value_tensor=torch.tensor([3], dtype=torch.float32), + input_min_shape=(2, 5), + input_opt_shape=(5, 5), + input_max_shape=(5, 5), + ), + param( + test_name="2d_indices_multiple", + indices_tensor=( + torch.tensor([0, 2, 2], dtype=torch.int32), + torch.tensor([2, 0, 2], dtype=torch.int32), + ), + value_tensor=torch.tensor([1, 3, 4], dtype=torch.float32), + input_min_shape=(2, 5), + input_opt_shape=(5, 5), + input_max_shape=(5, 5), + ), + param( + test_name="3d_indices_single", + indices_tensor=( + torch.tensor([1], dtype=torch.int32), + torch.tensor([2], dtype=torch.int32), + torch.tensor([2], dtype=torch.int32), + ), + value_tensor=torch.tensor([7], dtype=torch.float32), + input_min_shape=(2, 3, 3), + input_opt_shape=(3, 3, 3), + input_max_shape=(3, 3, 3), + ), + param( + test_name="3d_indices_multiple", + indices_tensor=( + torch.tensor([0, 1, 1], dtype=torch.int32), + torch.tensor([1, 2, 1], dtype=torch.int32), + torch.tensor([2, 0, 2], dtype=torch.int32), + ), + value_tensor=torch.tensor([5, 7, 2], dtype=torch.float32), + input_min_shape=(2, 3, 3), + input_opt_shape=(3, 3, 3), + input_max_shape=(3, 3, 3), + ), + param( + test_name="4d_indices_single", + indices_tensor=( + torch.tensor([1], dtype=torch.int32), + torch.tensor([1], dtype=torch.int32), + torch.tensor([0], dtype=torch.int32), + torch.tensor([1], dtype=torch.int32), + ), + value_tensor=torch.tensor([5], dtype=torch.float32), + input_min_shape=(1, 2, 2, 2), + input_opt_shape=(2, 2, 2, 2), + input_max_shape=(2, 2, 2, 2), + ), + param( + test_name="4d_indices_multiple", + indices_tensor=( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([1, 1], dtype=torch.int32), + torch.tensor([1, 0], dtype=torch.int32), + torch.tensor([1, 0], dtype=torch.int32), + ), + value_tensor=torch.tensor([5, 7], dtype=torch.float32), + input_min_shape=(1, 2, 2, 2), + input_opt_shape=(2, 2, 2, 2), + input_max_shape=(2, 2, 2, 2), + ), + param( + test_name="negative_indices", + indices_tensor=( + torch.tensor([-1, -2], dtype=torch.int32), + torch.tensor([2, 0], dtype=torch.int32), + ), + value_tensor=torch.tensor([1, 3], dtype=torch.float32), + input_min_shape=(2, 5), + input_opt_shape=(5, 5), + input_max_shape=(5, 5), + ), + param( + test_name="mixed_indices", + indices_tensor=( + torch.tensor([0, 1, -1, -2], dtype=torch.int32), + torch.tensor([0, -1, 2, 1], dtype=torch.int32), + ), + value_tensor=torch.tensor([2, 4, 6, 8], dtype=torch.float32), + input_min_shape=(2, 4), + input_opt_shape=(4, 4), + input_max_shape=(4, 4), + ), + param( + test_name="1d_indices_float", + indices_tensor=(torch.tensor([0, 3], dtype=torch.int32),), + value_tensor=torch.tensor([1.5, 3.5], dtype=torch.float32), + input_min_shape=(1,), + input_opt_shape=(5,), + input_max_shape=(5,), + ), + param( + test_name="2d_indices_float", + indices_tensor=( + torch.tensor([0, 2], dtype=torch.int32), + torch.tensor([2, 0], dtype=torch.int32), + ), + value_tensor=torch.tensor([1.5, 3.5], dtype=torch.float32), + input_min_shape=(1, 5), + input_opt_shape=(5, 5), + input_max_shape=(5, 5), + ), + param( + test_name="3d_indices_float", + indices_tensor=( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([1, 2], dtype=torch.int32), + torch.tensor([2, 0], dtype=torch.int32), + ), + value_tensor=torch.tensor([5.5, 7.5], dtype=torch.float32), + input_min_shape=(1, 3, 3), + input_opt_shape=(3, 3, 3), + input_max_shape=(3, 3, 3), + ), + param( + test_name="4d_indices_float", + indices_tensor=( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([1, 0], dtype=torch.int32), + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([1, 0], dtype=torch.int32), + ), + value_tensor=torch.tensor([5.5, 7.5], dtype=torch.float32), + input_min_shape=(1, 1, 2, 2), + input_opt_shape=(2, 2, 2, 2), + input_max_shape=(2, 2, 2, 2), + ), + ] + ) + def test_index_constant_dynamic( + self, + test_name, + indices_tensor, + value_tensor, + input_min_shape, + input_opt_shape, + input_max_shape, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.ops.aten.index_put.default( + input, indices_tensor, value_tensor, accumulate=False + ) + + input_specs = [ + Input( + min_shape=input_min_shape, + opt_shape=input_opt_shape, + max_shape=input_max_shape, + dtype=torch.float32, + ), + ] + self.run_test_with_dynamic_shape(TestModule(), input_specs) + + if __name__ == "__main__": run_tests()