
Commit 7f5c8aa

expanding cat converter to address CI error
1 parent 622404b commit 7f5c8aa

File tree

2 files changed: +56, -5 lines


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 39 additions & 5 deletions
@@ -2,7 +2,7 @@
 
 import logging
 import operator
-from typing import Callable, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
@@ -218,18 +218,51 @@ def aten_ops_native_group_norm(
     )
 
 
+def parse_cat_args(
+    args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+) -> Tuple[List[Any], int]:
+    """
+    Process inputs for torch.ops.aten.cat.default.
+
+    Handles these valid patterns:
+      1. args = ((t1, t2, ...), dim)
+      2. args = ((t1, t2, ...),), with an optional dim passed in kwargs
+
+    Returns:
+        (input_tensors, dim)
+        input_tensors: list of tensor arguments
+        dim: integer concatenation dimension (default 0)
+    """
+
+    if len(args) > 1 and isinstance(args[0], (list, tuple)):
+        input_tensors = list(args[0])
+        dim = args_bounds_check(args, 1, 0)
+
+    else:
+        # If the single positional arg is itself a tuple/list, unwrap it
+        if len(args) == 1 and isinstance(args[0], (list, tuple)):
+            input_tensors = list(args[0])
+        else:
+            input_tensors = list(args)
+
+        dim = kwargs.get("dim", 0)
+
+    return input_tensors, dim
+
+
 def cat_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool:
     # An empty tensor among the cat inputs as an ITensor leads to: [RemoveDeadLayers] Input Tensor y is unused or used only at compile-time, but is not being removed.
-    for each_input in node.args[0]:
+    inputs, _ = parse_cat_args(node.args, node.kwargs)
+    for each_input in inputs:
         if isinstance(each_input, TRTTensor) and any(s == 0 for s in each_input.shape):
             return False
     return True
 
 
 @dynamo_tensorrt_converter(
     torch.ops.aten.cat.default,
-    capability_validator=cat_validator,
     supports_dynamic_shapes=True,
+    capability_validator=cat_validator,
 )
 def aten_ops_cat(
     ctx: ConversionContext,
@@ -238,13 +271,14 @@ def aten_ops_cat(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    inputs, dim = parse_cat_args(args, kwargs)
     return impl.cat.cat(
         ctx,
         target,
         SourceIR.ATEN,
         name,
-        input=args[0],
-        dim=args_bounds_check(args, 1, 0),
+        input=inputs,
+        dim=dim,
     )
 
 
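For context on the converter change above: torch.ops.aten.cat.default can receive its concatenation dimension either positionally (args = ((t1, t2, ...), dim)) or as a keyword (kwargs = {"dim": ...}); the previous converter read dim only via args_bounds_check(args, 1, 0), so a keyword dim silently fell back to 0. Below is a minimal standalone sketch (an editor's illustration, not part of the commit) of the normalization parse_cat_args performs; the name parse_cat_args_sketch and the string stand-ins for tensors are assumptions made only to keep the example self-contained.

from typing import Any, Dict, List, Tuple


def parse_cat_args_sketch(
    args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[List[Any], int]:
    # Pattern 1: ((t1, t2, ...), dim) -- dim passed positionally
    if len(args) > 1 and isinstance(args[0], (list, tuple)):
        return list(args[0]), args[1]
    # Pattern 2: ((t1, t2, ...),) -- dim, if present, arrives via kwargs
    if len(args) == 1 and isinstance(args[0], (list, tuple)):
        return list(args[0]), kwargs.get("dim", 0)
    # Fallback: tensors passed directly as positional args
    return list(args), kwargs.get("dim", 0)


# Both call patterns normalize to the same (inputs, dim) pair
assert parse_cat_args_sketch((("t1", "t2"), 1), {}) == (["t1", "t2"], 1)
assert parse_cat_args_sketch((("t1", "t2"),), {"dim": 1}) == (["t1", "t2"], 1)
assert parse_cat_args_sketch((("t1", "t2"),), {}) == (["t1", "t2"], 0)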

tests/py/dynamo/conversion/test_cat_aten.py

Lines changed: 17 additions & 0 deletions
@@ -25,6 +25,23 @@ def forward(self, x, y, z):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            ("pos", 1),
+            ("neg", -2),
+        ]
+    )
+    def test_cat_dim_in_kwargs(self, _, dim):
+        class Cat(nn.Module):
+            def forward(self, x, y, z):
+                return torch.ops.aten.cat.default((x, y, z), dim=dim)
+
+        inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)]
+        self.run_test(
+            Cat(),
+            inputs,
+        )
+
     @parameterized.expand(
         [
             ("pos", 0),