Skip to content

Commit d308fd5

Browse files
ezyangfacebook-github-bot
authored andcommitted
FC preparation for int_oo in PyTorch (#3947)
Summary: Pull Request resolved: #3947 This is needed for pytorch/pytorch#127693 . This code is written so it is compatible before and after this PR. Reviewed By: mergennachin, clee2000 Differential Revision: D58465158 fbshipit-source-id: ca0f2a79eb07e78ff2887f78eb62ff38eeea3ede
1 parent 8f08b8b commit d308fd5

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

exir/passes/sym_shape_eval_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,6 @@ def call(self, graph_module: GraphModule):
259259
"Please use export's constrain_as_size() or constrain_as_value() apis and set a concrete upper bound to resolve this."
260260
)
261261

262-
spec.shape = concrete_shape # pyre-ignore[8]: Attribute `stride` declared in class `TensorSpec` has type `Tuple[int]` but is used as type `List[Optional[int]]`
263-
spec.stride = concrete_stride # pyre-ignore[8]: Attribute `stride` declared in class `TensorSpec` has type `Tuple[int]` but is used as type `List[Optional[int]]`
262+
spec.shape = concrete_shape
263+
spec.stride = concrete_stride # pyre-ignore[8]: Attribute `stride` declared in class `TensorSpec` has type `Tuple[int]` but is used as type `List[int]`
264264
return PassResult(graph_module, True)

exir/sym_util.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def eval_expr(symint: Union[int, torch.SymInt]) -> Optional[int]:
2929
return int(output)
3030

3131

32-
def eval_upper_bound(maybe_symint: Union[int, torch.SymInt]) -> Optional[int]:
32+
def eval_upper_bound(maybe_symint: Union[int, torch.SymInt]) -> int:
3333
"""
3434
Evaluate a symint to its uppper bound value. Returns None if symint's symoblic expr's
3535
upper bound can not be evaluated to valid integer according to the constraints in shape_env.
@@ -41,17 +41,24 @@ def eval_upper_bound(maybe_symint: Union[int, torch.SymInt]) -> Optional[int]:
4141
expr = node.expr
4242
var_range: ValueRanges = bound_sympy(expr, shape_env.var_to_range)
4343
upper_bound = var_range.upper
44+
# This import is needed temporarily until we update the pinned torch version.
45+
46+
try:
47+
from torch.utils._sympy.numbers import int_oo # @manual # pyre-ignore
48+
except ImportError:
49+
int_oo = None
50+
4451
if isinstance(upper_bound, sympy.Integer):
4552
concrete_upper = int(var_range.upper)
4653
assert isinstance(
4754
concrete_upper, int
4855
), f"Expect upper bound to be a concrete int but got {concrete_upper}"
4956
return concrete_upper
50-
elif isinstance(upper_bound, sympy.oo):
51-
return None
57+
elif int_oo is not None and upper_bound is int_oo: # pyre-ignore
58+
return int_oo # pyre-ignore
5259
else:
5360
raise RuntimeError(
54-
f"Expect upper bound to be sympy.Integer or sympy.oo. but got {upper_bound}"
61+
f"Expect upper bound to be sympy.Integer or int_oo. but got {upper_bound}"
5562
)
5663

5764

0 commit comments

Comments
 (0)