@@ -259,6 +259,32 @@ def tensor_node_is_bool(node: torch.fx.Node) -> bool:
259259 return False
260260
261261
262+ def ndim_of (node : Any ) -> Optional [int ]:
263+ """
264+ Returns the number of dimensions of the tensor produced by the given node
265+ """
266+ if not is_single_tensor_node (node ):
267+ return None
268+
269+ return node .meta ["val" ].ndim
270+
271+
272+ def is_unsqueezed_vector (node : torch .fx .Node ) -> bool :
273+ """
274+ Returns True if the node's tensor has all dimensions equal to 1 except for the last dimension.
275+ """
276+ if not is_single_tensor_node (node ):
277+ return False
278+
279+ tensor = node .meta ["val" ]
280+ assert isinstance (tensor , FakeTensor )
281+
282+ if len (tensor .shape ) < 1 :
283+ return False
284+ # All dims except last are 1, last can be any size
285+ return all (dim == 1 for dim in tensor .shape [:- 1 ])
286+
287+
262288def op_contains_bool_tensor (node : torch .fx .Node ) -> bool :
263289 """
264290 Returns true if the operator used to compute the given node contains a bool tensor
@@ -267,6 +293,7 @@ def op_contains_bool_tensor(node: torch.fx.Node) -> bool:
267293 return True
268294
269295 for arg_node in node .args :
296+ # pyre-ignore[6]
270297 if is_tensor_node (arg_node ) and tensor_node_is_bool (arg_node ):
271298 return True
272299
@@ -1250,6 +1277,26 @@ def is_in_8bit_range(tensor: torch.Tensor) -> bool:
12501277##
12511278
12521279
1280+ def normalize_dims (dims : Union [int , List [int ]], ndim : int ) -> Union [int , List [int ]]:
1281+ """
1282+ Normalize dimension indices to be non-negative and within [0, ndim).
1283+ Accepts a single int or a list of ints.
1284+ """
1285+ if isinstance (dims , int ):
1286+ if dims < 0 :
1287+ dims += ndim
1288+
1289+ return dims
1290+
1291+ normalized = []
1292+ for d in dims :
1293+ if d < 0 :
1294+ d += ndim
1295+ normalized .append (d )
1296+
1297+ return normalized
1298+
1299+
12531300def nchw_dim_to_whcn_dim (nchw_dim : int , ndim : int ) -> int :
12541301 # Handle negative indices for nchw_dim
12551302 if nchw_dim < 0 :
0 commit comments