From 0f15e7da579472168578341b7ef4236354e33e5f Mon Sep 17 00:00:00 2001
From: apbose
Date: Thu, 16 Oct 2025 12:45:11 -0700
Subject: [PATCH 1/4] addressing cat empty tensor case. Fixes gpt2 data
 distributed example

---
 .../data_parallel_stable_diffusion.py        |  2 -
 .../dynamo/conversion/aten_ops_converters.py | 12 ++++-
 .../dynamo/conversion/impl/cat.py            |  2 +
 tests/py/dynamo/conversion/test_cat_aten.py  | 54 +++++++++++++++++++
 4 files changed, 67 insertions(+), 3 deletions(-)

diff --git a/examples/distributed_inference/data_parallel_stable_diffusion.py b/examples/distributed_inference/data_parallel_stable_diffusion.py
index 5c0e3113e5..023d7e8e63 100644
--- a/examples/distributed_inference/data_parallel_stable_diffusion.py
+++ b/examples/distributed_inference/data_parallel_stable_diffusion.py
@@ -53,7 +53,5 @@
 
 # Assume there are 2 processes (2 devices)
 with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
-    print("before \n")
     result = pipe(prompt).images[0]
-    print("after ")
     result.save(f"result_{distributed_state.process_index}.png")
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 147813d8e0..81f359fa70 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -218,7 +218,17 @@ def aten_ops_native_group_norm(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.cat.default, supports_dynamic_shapes=True)
+def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool:
+    # Validate only one user, which is a getitem node that accesses the first element in the list
+    for each_input in node.args[0]:
+        if isinstance(each_input, TRTTensor) and any(s == 0 for s in each_input.shape):
+            return False
+    return True
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.cat.default, supports_dynamic_shapes=True, validator=cat_validator
+)
 def aten_ops_cat(
     ctx: ConversionContext,
     target: Target,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cat.py b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
index 0553d766c1..8f286bc55c 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/cat.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/cat.py
@@ -14,6 +14,8 @@
     set_layer_name,
 )
 
+logger = logging.getLogger(__name__)
+
 
 def unify_and_concat_trt_tensors(
     ctx: ConversionContext,
diff --git a/tests/py/dynamo/conversion/test_cat_aten.py b/tests/py/dynamo/conversion/test_cat_aten.py
index a9e4a45c81..4d7bc02d1f 100644
--- a/tests/py/dynamo/conversion/test_cat_aten.py
+++ b/tests/py/dynamo/conversion/test_cat_aten.py
@@ -25,6 +25,60 @@ def forward(self, x, y, z):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            ("pos", 0),
+            ("neg", -3),
+        ]
+    )
+    def test_cat_with_scalar_inputs(self, _, dim):
+        # Ensure scalar tensor wrap works
+        class Cat(nn.Module):
+            def forward(self, x, y):
+                # y is a scalar, x is a tensor
+                return torch.ops.aten.cat.default((x, y), dim)
+
+        x = torch.randn(1, 2, 3, device="cuda")
+        y = torch.ones_like(x) * 5.0  # simulate scalar broadcast
+        inputs = [x, y]
+        self.run_test(Cat(), inputs)
+
+    @parameterized.expand(
+        [
+            ("pos", 0),
+            ("neg", -3),
+        ]
+    )
+    def test_cat_with_empty_tensor(self, _, dim):
+        # Handle empty tensor in concat
+        class Cat(nn.Module):
+            def forward(self, x):
+                y = torch.empty(0, 2, 3, device="cuda")
+                return torch.ops.aten.cat.default((x, y), dim)
+
+        inputs = [
+            torch.randn(1, 2, 3, device="cuda"),
+        ]
+        self.run_test(Cat(), inputs)
+
+    @parameterized.expand(
+        [
+            ("pos", 2),
+            ("neg", -1),
+        ]
+    )
+    def test_cat_with_different_dtypes(self, _, dim):
+        # check dtype promotion path in concat
+        class Cat(nn.Module):
+            def forward(self, x, y):
+                return torch.ops.aten.cat.default((x, y), dim)
+
+        inputs = [
+            torch.ones(1, 2, 3, dtype=torch.float32, device="cuda"),
+            torch.ones(1, 2, 3, dtype=torch.float16, device="cuda"),
+        ]
+        self.run_test(Cat(), inputs)
+
     @parameterized.expand(
         [
             ("pos", 1),
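For context on the fallback strategy in PATCH 1: eager PyTorch accepts a zero-size tensor in cat, so the graph is valid for Torch even where TensorRT reports the input as dead. A quick eager-mode check (illustrative snippet, not part of the patch):

import torch

x = torch.randn(1, 2, 3)
y = torch.empty(0, 2, 3)
# The empty tensor contributes zero rows along the concat dimension.
print(torch.cat((x, y), dim=0).shape)  # torch.Size([1, 2, 3])

Rejecting such nodes in the validator therefore keeps the op running in PyTorch instead of failing the TensorRT engine build.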
("pos", 2), + ("neg", -1), + ] + ) + def test_cat_with_different_dtypes(self, _, dim): + # check dtype promotion path in concat + class Cat(nn.Module): + def forward(self, x, y): + return torch.ops.aten.cat.default((x, y), dim) + + inputs = [ + torch.ones(1, 2, 3, dtype=torch.float32, device="cuda"), + torch.ones(1, 2, 3, dtype=torch.float16, device="cuda"), + ] + self.run_test(Cat(), inputs) + @parameterized.expand( [ ("pos", 1), From 622404b6856ff95cf6bfa301b26a830ee980c64b Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 16 Oct 2025 12:50:19 -0700 Subject: [PATCH 2/4] correcting the validator error message --- py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 81f359fa70..eca7e19924 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -219,7 +219,7 @@ def aten_ops_native_group_norm( def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool: - # Validate only one user, which is a getitem node that accesses the first element in the list + # empty tensor in cat input as ITensor leads to [RemoveDeadLayers] Input Tensor y is unused or used only at compile-time, but is not being removed. for each_input in node.args[0]: if isinstance(each_input, TRTTensor) and any(s == 0 for s in each_input.shape): return False @@ -227,7 +227,9 @@ def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> @dynamo_tensorrt_converter( - torch.ops.aten.cat.default, supports_dynamic_shapes=True, validator=cat_validator + torch.ops.aten.cat.default, + capability_validator=cat_validator, + supports_dynamic_shapes=True, ) def aten_ops_cat( ctx: ConversionContext, From 7f5c8aabc78c5cc3069dd80dbfb72617dd0942e2 Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 17 Oct 2025 13:45:34 -0700 Subject: [PATCH 3/4] expanding cat converter to address CI error --- .../dynamo/conversion/aten_ops_converters.py | 44 ++++++++++++++++--- tests/py/dynamo/conversion/test_cat_aten.py | 17 +++++++ 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index eca7e19924..e91d18298c 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2,7 +2,7 @@ import logging import operator -from typing import Callable, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -218,9 +218,42 @@ def aten_ops_native_group_norm( ) +def parse_cat_args( + args: Tuple[Argument, ...], kwargs: Dict[str, Any] +) -> Tuple[List[Any], int]: + """ + Process inputs for torch.ops.aten.cat.default. + + Handles these valid patterns: + 1. args = ((t1, t2, ...), dim) + 2. 
From 7f5c8aabc78c5cc3069dd80dbfb72617dd0942e2 Mon Sep 17 00:00:00 2001
From: apbose
Date: Fri, 17 Oct 2025 13:45:34 -0700
Subject: [PATCH 3/4] expanding cat converter to address CI error

---
 .../dynamo/conversion/aten_ops_converters.py | 44 ++++++++++++++++---
 tests/py/dynamo/conversion/test_cat_aten.py  | 17 +++++++
 2 files changed, 56 insertions(+), 5 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index eca7e19924..e91d18298c 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2,7 +2,7 @@
 
 import logging
 import operator
-from typing import Callable, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
@@ -218,9 +218,42 @@ def aten_ops_native_group_norm(
     )
 
 
+def parse_cat_args(
+    args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+) -> Tuple[List[Any], int]:
+    """
+    Process inputs for torch.ops.aten.cat.default.
+
+    Handles these valid patterns:
+    1. args = ((t1, t2, ...), dim)
+    2. args = ((t1, t2, ...),), with optional dim passed in kwargs
+
+    Returns:
+        (input_tensors, dim)
+        input_tensors: list of tensor arguments
+        dim: integer concatenation dimension (default 0)
+    """
+
+    if len(args) > 1 and isinstance(args[0], (list, tuple)):
+        input_tensors = list(args[0])
+        dim = args_bounds_check(args, 1, 0)
+
+    else:
+        # If single arg is itself a tuple/list, unwrap it
+        if len(args) == 1 and isinstance(args[0], (list, tuple)):
+            input_tensors = list(args[0])
+        else:
+            input_tensors = list(args)
+
+        dim = kwargs.get("dim", 0)
+
+    return input_tensors, dim
+
+
 def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool:
     # empty tensor in cat input as ITensor leads to [RemoveDeadLayers] Input Tensor y is unused or used only at compile-time, but is not being removed.
-    for each_input in node.args[0]:
+    inputs, _ = parse_cat_args(node.args, node.kwargs)
+    for each_input in inputs:
         if isinstance(each_input, TRTTensor) and any(s == 0 for s in each_input.shape):
             return False
     return True
@@ -228,8 +261,8 @@ def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -
 
 @dynamo_tensorrt_converter(
     torch.ops.aten.cat.default,
-    capability_validator=cat_validator,
     supports_dynamic_shapes=True,
+    capability_validator=cat_validator,
 )
 def aten_ops_cat(
     ctx: ConversionContext,
@@ -238,13 +271,14 @@ def aten_ops_cat(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    inputs, dim = parse_cat_args(args, kwargs)
     return impl.cat.cat(
         ctx,
         target,
         SourceIR.ATEN,
         name,
-        input=args[0],
-        dim=args_bounds_check(args, 1, 0),
+        input=inputs,
+        dim=dim,
     )
 
 
diff --git a/tests/py/dynamo/conversion/test_cat_aten.py b/tests/py/dynamo/conversion/test_cat_aten.py
index 4d7bc02d1f..15aa8b0d80 100644
--- a/tests/py/dynamo/conversion/test_cat_aten.py
+++ b/tests/py/dynamo/conversion/test_cat_aten.py
@@ -25,6 +25,23 @@ def forward(self, x, y, z):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            ("pos", 1),
+            ("neg", -2),
+        ]
+    )
+    def test_cat_dim_in_kwargs(self, _, dim):
+        class Cat(nn.Module):
+            def forward(self, x, y, z):
+                return torch.ops.aten.cat.default((x, y, z), dim=dim)
+
+        inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)]
+        self.run_test(
+            Cat(),
+            inputs,
+        )
+
     @parameterized.expand(
         [
             ("pos", 0),
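To make the two calling conventions in the parse_cat_args docstring concrete, here is a simplified standalone mirror of its logic (a sketch for illustration, not the committed helper; plain strings stand in for tensors):

def parse_cat_args_sketch(args, kwargs):
    if len(args) > 1 and isinstance(args[0], (list, tuple)):
        return list(args[0]), args[1]  # pattern 1: positional dim
    if len(args) == 1 and isinstance(args[0], (list, tuple)):
        return list(args[0]), kwargs.get("dim", 0)  # pattern 2: dim in kwargs
    return list(args), kwargs.get("dim", 0)

t1, t2 = "t1", "t2"  # placeholders for tensor arguments
assert parse_cat_args_sketch(((t1, t2), 1), {}) == ([t1, t2], 1)
assert parse_cat_args_sketch(((t1, t2),), {"dim": -2}) == ([t1, t2], -2)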
From 5ce93460590a710822a74045f0f5dac79682b34c Mon Sep 17 00:00:00 2001
From: apbose
Date: Thu, 6 Nov 2025 10:53:33 -0800
Subject: [PATCH 4/4] undoing the empty tensor changes, keeping the cat
 different-case handling

---
 .../dynamo/conversion/aten_ops_converters.py | 10 ------
 tests/py/dynamo/conversion/test_cat_aten.py  | 36 -------------------
 2 files changed, 46 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index e91d18298c..f3eed29727 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -250,19 +250,9 @@ def parse_cat_args(
     return input_tensors, dim
 
 
-def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool:
-    # empty tensor in cat input as ITensor leads to [RemoveDeadLayers] Input Tensor y is unused or used only at compile-time, but is not being removed.
-    inputs, _ = parse_cat_args(node.args, node.kwargs)
-    for each_input in inputs:
-        if isinstance(each_input, TRTTensor) and any(s == 0 for s in each_input.shape):
-            return False
-    return True
-
-
 @dynamo_tensorrt_converter(
     torch.ops.aten.cat.default,
     supports_dynamic_shapes=True,
-    capability_validator=cat_validator,
 )
 def aten_ops_cat(
     ctx: ConversionContext,
diff --git a/tests/py/dynamo/conversion/test_cat_aten.py b/tests/py/dynamo/conversion/test_cat_aten.py
index 15aa8b0d80..f823c6f565 100644
--- a/tests/py/dynamo/conversion/test_cat_aten.py
+++ b/tests/py/dynamo/conversion/test_cat_aten.py
@@ -60,42 +60,6 @@ def forward(self, x, y):
         inputs = [x, y]
         self.run_test(Cat(), inputs)
 
-    @parameterized.expand(
-        [
-            ("pos", 0),
-            ("neg", -3),
-        ]
-    )
-    def test_cat_with_empty_tensor(self, _, dim):
-        # Handle empty tensor in concat
-        class Cat(nn.Module):
-            def forward(self, x):
-                y = torch.empty(0, 2, 3, device="cuda")
-                return torch.ops.aten.cat.default((x, y), dim)
-
-        inputs = [
-            torch.randn(1, 2, 3, device="cuda"),
-        ]
-        self.run_test(Cat(), inputs)
-
-    @parameterized.expand(
-        [
-            ("pos", 2),
-            ("neg", -1),
-        ]
-    )
-    def test_cat_with_different_dtypes(self, _, dim):
-        # check dtype promotion path in concat
-        class Cat(nn.Module):
-            def forward(self, x, y):
-                return torch.ops.aten.cat.default((x, y), dim)
-
-        inputs = [
-            torch.ones(1, 2, 3, dtype=torch.float32, device="cuda"),
-            torch.ones(1, 2, 3, dtype=torch.float16, device="cuda"),
-        ]
-        self.run_test(Cat(), inputs)
-
     @parameterized.expand(
         [
             ("pos", 1),
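With the series applied, the converter still normalizes cat nodes that pass dim as a keyword, presumably the call pattern behind the CI error PATCH 3 addresses. A minimal eager repro of that call shape (the module name and the commented compile call are illustrative assumptions, not taken from the patches):

import torch

class CatKwargs(torch.nn.Module):
    def forward(self, x, y):
        return torch.ops.aten.cat.default((x, y), dim=-1)

mod = CatKwargs().eval().cuda()  # assumes a CUDA device is available
x = torch.randn(1, 2, 3, device="cuda")
y = torch.randn(1, 2, 3, device="cuda")
print(mod(x, y).shape)  # torch.Size([1, 2, 6])
# import torch_tensorrt
# trt_mod = torch_tensorrt.compile(mod, ir="dynamo", inputs=[x, y])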