Skip to content

Commit 68d41bb

Browse files
mcr229facebook-github-bot
authored andcommitted
update dynamic shape detection
Summary: Updating Dynamic Shape Detection re: #5794 Differential Revision: D68036835
1 parent 9666ee8 commit 68d41bb

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

backends/xnnpack/operators/op_squeeze.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
XNNStaticReshape,
1717
XNode,
1818
)
19+
from torch.fx.experimental.symbolic_shapes import (
20+
free_symbols,
21+
)
22+
1923
from executorch.backends.xnnpack.utils.utils import check_or_raise, get_input_node
2024

2125

@@ -57,7 +61,7 @@ def define_node(
5761

5862
num_dynamic_dims = 0
5963
for dim in dynamic_shape:
60-
if isinstance(dim, torch.SymInt):
64+
if free_symbols(dim):
6165
num_dynamic_dims += 1
6266
new_shape.append(0)
6367
else:
@@ -119,7 +123,7 @@ def define_node(
119123

120124
num_dynamic_dims = 0
121125
for dim in dynamic_shape:
122-
if isinstance(dim, torch.SymInt):
126+
if free_symbols(dim):
123127
num_dynamic_dims += 1
124128
new_shape.append(0)
125129
else:

exir/backend/utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from executorch.exir.lowered_backend_module import create_submodule_from_nodes
2525
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
26+
from torch.fx.experimental.symbolic_shapes import has_free_symbols
2627
from torch.fx.node import Node
2728
from torch.fx.passes.utils.source_matcher_utils import SourcePartition
2829

@@ -424,10 +425,7 @@ def is_shape_dynamic(node: torch.fx.Node) -> bool:
424425
Check if the node shape is dynamic.
425426
"""
426427

427-
# Shape is dynamic if any of the dimensions don't evaluate to a static value
428-
return "val" in node.meta and any(
429-
isinstance(d, torch.SymInt) for d in node.meta["val"].shape
430-
)
428+
return has_free_symbols(node.meta["val"].shape)
431429

432430

433431
# TODO - style: use templated types

0 commit comments

Comments
 (0)