From 20d63781f1313caaa6e5896684d7faf38841200d Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Tue, 14 Jan 2025 13:13:15 -0800 Subject: [PATCH] Make executorch stride sort size oblivious (#7657) Summary: This is needed to unblock https://github.com/pytorch/pytorch/pull/144695 as that PR triggers a failure in executorch test suite Reviewed By: angelayi Differential Revision: D68169867 --- exir/tensor.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/exir/tensor.py b/exir/tensor.py index e42bf738056..1345067354f 100644 --- a/exir/tensor.py +++ b/exir/tensor.py @@ -13,7 +13,7 @@ import math import typing -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, NamedTuple, Optional, Tuple, Union import executorch.exir.schema as schema import torch @@ -70,8 +70,29 @@ def dim_order_from_stride(stride: Tuple[int]) -> Tuple[bytes]: for _, s in enumerate(stride): if s == 0: raise ValueError("0 in strides is not supported for ExecuTorch.") + + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + class K(NamedTuple): + stride: int + + def __lt__(self, other): + return guard_size_oblivious(self.stride < other.stride) + + def __gt__(self, other): + return guard_size_oblivious(self.stride > other.stride) + + def __le__(self, other): + return guard_size_oblivious(self.stride <= other.stride) + + def __ge__(self, other): + return guard_size_oblivious(self.stride >= other.stride) + + def __eq__(self, other): + return guard_size_oblivious(self.stride == other.stride) + sorted_dims = [ - i[0] for i in sorted(enumerate(stride), key=lambda x: x[1], reverse=True) + i[0] for i in sorted(enumerate(stride), key=lambda x: K(x[1]), reverse=True) ] return tuple(typing.cast(Tuple[bytes], sorted_dims))