Skip to content

Commit 7cdf504

Browse files
aorenstepytorchmergebot
authored andcommitted
Fix evaluate_expr to include suppress_guards_tls in cache key (pytorch#152661)
ShapeEnv.evaluate_expr() behaves differently based on the (tls) global "suppress_guards" - so its cache key needs to include that value. This came up because pytorch#152662 triggered it in the test `test/dynamo/test_exc.py::ExcTests::test_trigger_bisect_on_error` - fixing this caused that test to work again. Pull Request resolved: pytorch#152661 Approved by: https://github.com/laithsakka
1 parent 30a3c5d commit 7cdf504

File tree

2 files changed

+38
-10
lines changed

2 files changed

+38
-10
lines changed

torch/fx/experimental/recording.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,10 @@ def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv:
214214
# save_tracked_fakes: saves a snapshot of the TrackedFake list.
215215
# This is used when calling ShapeEnv.produce_guards at arbitrary points in time.
216216
#
217+
# name: the name of the function being recorded. Normally (and by default) this
218+
# is taken from the decorated function but can be set if you need to override
219+
# it.
220+
#
217221
# When to save the list of TrackedFake?
218222
# =====================================
219223
# We should save the list of TrackedFake whenever the translation validation
@@ -225,15 +229,19 @@ def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv:
225229
# At the moment, there are 2 methods that save the list:
226230
# - ShapeEnv.evaluate_expr
227231
# - ShapeEnv.defer_runtime_assert
228-
def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
232+
def record_shapeenv_event(
233+
*, save_tracked_fakes: bool = False, name: Optional[str] = None
234+
) -> Callable:
229235
def decorator(fn: Callable) -> Callable:
230236
assert callable(fn)
231237
args = inspect.getfullargspec(fn).args
232238
assert args and args[0] == "self", (
233239
"record_shapeenv_event should only wrap methods on ShapeEnv; refactor your "
234240
"code so that it calls into a method on ShapeEnv"
235241
)
236-
name = fn.__name__
242+
nonlocal name
243+
if name is None:
244+
name = fn.__name__
237245

238246
@functools.wraps(fn)
239247
def wrapper(*args, **kwargs):
@@ -281,7 +289,11 @@ def retlog(r):
281289
)
282290
# Record the event for 'fn'.
283291
event = ShapeEnvEvent(
284-
fn, list(args), kwargs, tracked_fakes, name=fn.__name__
292+
fn,
293+
list(args),
294+
kwargs,
295+
tracked_fakes,
296+
name=name,
285297
)
286298
# Play the event on this ShapeEnv.
287299
# NB: It's important to put the event first, because running

torch/fx/experimental/symbolic_shapes.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3936,7 +3936,8 @@ def _add_fx_node_metadata(self, node: torch.fx.Node) -> None:
39363936
node.meta[SHAPEENV_EVENT_KEY] = self._last_event_index()
39373937
node.meta[CURRENT_NODE_KEY] = get_current_node()
39383938

3939-
def _suppress_guards_tls(self) -> bool:
3939+
@staticmethod
3940+
def _suppress_guards_tls() -> bool:
39403941
return getattr(TLS, "suppress_guards", False)
39413942

39423943
@record_shapeenv_event()
@@ -6811,8 +6812,6 @@ def evaluate_sym_node(
68116812
sym_node.expr, sym_node.hint, sym_node.fx_node, size_oblivious
68126813
)
68136814

6814-
@lru_cache(256)
6815-
@record_shapeenv_event(save_tracked_fakes=True)
68166815
def evaluate_expr(
68176816
self,
68186817
orig_expr: sympy.Basic,
@@ -6821,6 +6820,27 @@ def evaluate_expr(
68216820
size_oblivious: bool = False,
68226821
*,
68236822
forcing_spec: bool = False,
6823+
) -> sympy.Basic:
6824+
"""
6825+
Given an expression, evaluates it, adding guards if necessary
6826+
"""
6827+
6828+
# Add extra state that evaluate_expr() depends on.
6829+
suppress_guards_tls = ShapeEnv._suppress_guards_tls()
6830+
return self._inner_evaluate_expr(
6831+
orig_expr, hint, fx_node, size_oblivious, forcing_spec, suppress_guards_tls
6832+
)
6833+
6834+
@lru_cache(256)
6835+
@record_shapeenv_event(save_tracked_fakes=True, name="evaluate_expr")
6836+
def _inner_evaluate_expr(
6837+
self,
6838+
orig_expr: sympy.Basic,
6839+
hint: Optional[Union[int, bool, float]],
6840+
fx_node: Optional[torch.fx.Node],
6841+
size_oblivious: bool,
6842+
forcing_spec: bool,
6843+
_suppress_guards_tls: bool,
68246844
) -> sympy.Basic:
68256845
try:
68266846
return self._evaluate_expr(
@@ -6852,10 +6872,6 @@ def _evaluate_expr(
68526872
*,
68536873
forcing_spec: bool = False,
68546874
) -> sympy.Basic:
6855-
"""
6856-
Given an expression, evaluates it, adding guards if necessary
6857-
"""
6858-
68596875
# TODO: split conjunctions and evaluate them separately
68606876

68616877
if isinstance(

0 commit comments

Comments
 (0)