22
33import logging
44import operator
5- from typing import Callable , Dict , Optional , Sequence , Tuple , Union
5+ from typing import Any , Callable , Dict , List , Optional , Sequence , Tuple , Union
66
77import numpy as np
88import torch
@@ -218,18 +218,51 @@ def aten_ops_native_group_norm(
218218 )
219219
220220
221+ def parse_cat_args (
222+ args : Tuple [Argument , ...], kwargs : Dict [str , Any ]
223+ ) -> Tuple [List [Any ], int ]:
224+ """
225+ Process inputs for torch.ops.aten.cat.default.
226+
227+ Handles these valid patterns:
228+ 1. args = ((t1, t2, ...), dim)
229+ 2. args = ((t1, t2, ...),), kwargs = {dim: X} with optional dim in kwargs
230+
231+ Returns:
232+ (input_tensors, dim)
233+ input_tensors: tuple of tensor arguments
234+ dim: integer concatenation dimension (default 0)
235+ """
236+
237+ if len (args ) > 1 and isinstance (args [0 ], (list , tuple )):
238+ input_tensors = list (args [0 ])
239+ dim = args_bounds_check (args , 1 , 0 )
240+
241+ else :
242+ # If single arg is itself a tuple/list, unwrap it
243+ if len (args ) == 1 and isinstance (args [0 ], (list , tuple )):
244+ input_tensors = list (args [0 ])
245+ else :
246+ input_tensors = list (args )
247+
248+ dim = kwargs .get ("dim" , 0 )
249+
250+ return input_tensors , dim
251+
252+
221253def cat_validator (node : Node , settings : Optional [CompilationSettings ] = None ) -> bool :
222254 # empty tensor in cat input as ITensor leads to [RemoveDeadLayers] Input Tensor y is unused or used only at compile-time, but is not being removed.
223- for each_input in node .args [0 ]:
255+ inputs , _ = parse_cat_args (node .args , node .kwargs )
256+ for each_input in inputs :
224257 if isinstance (each_input , TRTTensor ) and any (s == 0 for s in each_input .shape ):
225258 return False
226259 return True
227260
228261
229262@dynamo_tensorrt_converter (
230263 torch .ops .aten .cat .default ,
231- capability_validator = cat_validator ,
232264 supports_dynamic_shapes = True ,
265+ capability_validator = cat_validator ,
233266)
234267def aten_ops_cat (
235268 ctx : ConversionContext ,
@@ -238,13 +271,14 @@ def aten_ops_cat(
238271 kwargs : Dict [str , Argument ],
239272 name : str ,
240273) -> Union [TRTTensor , Sequence [TRTTensor ]]:
274+ inputs , dim = parse_cat_args (args , kwargs )
241275 return impl .cat .cat (
242276 ctx ,
243277 target ,
244278 SourceIR .ATEN ,
245279 name ,
246- input = args [ 0 ] ,
247- dim = args_bounds_check ( args , 1 , 0 ) ,
280+ input = inputs ,
281+ dim = dim ,
248282 )
249283
250284
0 commit comments