@@ -350,14 +350,28 @@ def is_slice_view(self, node: torch.fx.Node) -> bool:
     def is_cat_along_outermost_dim(
         self, graph_module: torch.fx.GraphModule, cat_node: torch.fx.Node
     ) -> bool:
+        assert len(cat_node.args) > 0
+        cat_tensors = cat_node.args[0]
+        if not isinstance(cat_tensors, Sequence) or not all(
+            isinstance(t, torch.fx.Node) for t in cat_tensors
+        ):
+            raise ValueError("cat_tensors must be a sequence of torch.fx.Node objects.")
+
+        if len(cat_node.args) > 1:
+            cat_dim = cat_node.args[1]
+        else:
+            cat_dim = cat_node.kwargs.get("dim", 0)  # torch.cat defaults to dim=0
+        if not isinstance(cat_dim, int):
+            raise ValueError("cat_dim must be an integer.")
+
         # If the cat op has default dim, then the concat dim is 0
-        if len(cat_node.args) == 1 or cat_node.args[1] == 0:
+        if len(cat_tensors) == 1 or cat_dim == 0:
             return True
-        # Get the concatenation dimension and concatenated tensors
-        (cat_tensors, cat_dim) = cast(
-            tuple[Sequence[torch.fx.Node], int], cat_node.args
-        )
+
+        # Make sure all dims before cat_dim are 1.
         for tensor in cat_tensors:
+            if not isinstance(tensor, torch.fx.Node):
+                continue
             shape = get_shape(graph_module, tensor)
             if shape is None or not all(dim == 1 for dim in shape[0:cat_dim]):
                 return False
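
A minimal standalone sketch, not part of the PR (the module name M is made up for illustration; it assumes only torch and torch.fx), of the two facts the new code relies on: a cat node traced from a dim= keyword call keeps its tensors in args[0] while the dim lands in kwargs, and a cat whose inputs have all-singleton dims before cat_dim behaves like a concatenation along the flattened outermost dimension.

    import torch
    import torch.fx


    class M(torch.nn.Module):
        def forward(self, x, y):
            return torch.cat([x, y], dim=2)


    gm = torch.fx.symbolic_trace(M())
    cat_node = next(n for n in gm.graph.nodes if n.target is torch.cat)
    print(cat_node.args)    # ([x, y],) -- args[0] is the sequence of input Nodes
    print(cat_node.kwargs)  # {'dim': 2} -- dim may live in kwargs instead of args[1]

    # With every dim before dim=2 equal to 1, the cat is equivalent to an
    # outermost (flattened) concatenation of its inputs.
    a = torch.arange(3).reshape(1, 1, 3)
    b = torch.arange(3, 7).reshape(1, 1, 4)
    assert torch.equal(
        torch.cat([a, b], dim=2).flatten(),
        torch.cat([a.flatten(), b.flatten()]),
    )

    # Counterexample: a leading dim > 1 breaks the equivalence, which is why
    # the pass returns False in that case.
    c = torch.arange(6).reshape(2, 3)
    d = torch.arange(6, 14).reshape(2, 4)
    assert not torch.equal(
        torch.cat([c, d], dim=1).flatten(),
        torch.cat([c.flatten(), d.flatten()]),
    )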