Skip to content

Commit d479852

Browse files
Pian Pawakapanfacebook-github-bot
authored andcommitted
guard_or_false in dim_order == 0 check
Differential Revision: D78585234
1 parent 439bb6c commit d479852

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

exir/tensor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,12 @@ def dim_order_from_stride(stride: Tuple[int]) -> Tuple[bytes]:
6767
Another example is: sizes = (1, 3, 1, 1) with strides = (3, 1, 3, 3), returned
6868
value is (0, 2, 3, 1)
6969
"""
70+
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, guard_or_false
71+
7072
for _, s in enumerate(stride):
71-
if s == 0:
73+
if guard_or_false(s == 0):
7274
raise ValueError("0 in strides is not supported for ExecuTorch.")
7375

74-
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
75-
7676
class K(NamedTuple):
7777
stride: int
7878

exir/tests/test_serde.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,35 @@ def forward(self, x):
263263
== edge_deserialized.to_executorch().buffer
264264
)
265265

266+
def test_dim_order_from_stride(self):
267+
from executorch.exir import EdgeCompileConfig
268+
269+
class Test(torch.nn.Module):
270+
def __init__(self):
271+
super().__init__()
272+
273+
def forward(self, t1, t2):
274+
idx = torch.nonzero(t1).reshape(-1)
275+
y = torch.index_select(t2, 0, idx)
276+
return y
277+
278+
279+
M = Test()
280+
x = torch.tensor([0, 1, 1, 0, 1], dtype=torch.bool)
281+
y = torch.randn(5, 6)
282+
M(x, y)
283+
284+
expo_prog = torch.export.export_for_training(M, (x, y))
285+
print(expo_prog)
286+
287+
edge_prog = to_edge_transform_and_lower(
288+
expo_prog,
289+
partitioner=[XnnpackFloatingPointPartitioner()],
290+
compile_config=EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=True),
291+
)
292+
edge_prog.to_executorch()
293+
breakpoint()
294+
266295
def test_meta_stack_trace_module_hierarchy(self) -> None:
267296
class Model(nn.Module):
268297
def __init__(self):

0 commit comments

Comments
 (0)