Skip to content

Commit 1302637

Browse files
Revert "[dynamo][guards] Do not construct entire framelocals dict for LAMBDA_GUARD (pytorch#162525)"
This reverts commit 5f630d2. Reverted pytorch#162525 on behalf of https://github.com/anijain2305 due to internal tests fail ([comment](pytorch#162525 (comment)))
1 parent e0bcd58 commit 1302637

File tree

7 files changed

+19
-128
lines changed

7 files changed

+19
-128
lines changed

test/dynamo/test_guard_manager.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,6 @@ def test_python_lambda_leaf_guard(self):
116116
const_guard = guards.LAMBDA_GUARD(
117117
root,
118118
functools.partial(equals_match, expected=5),
119-
{},
120-
False,
121119
equals_match_verbose_code_parts(5),
122120
)
123121
self.assertTrue(const_guard(5))
@@ -407,14 +405,10 @@ def test_guard_manager_leaf_guard(self):
407405
guard_manager.add_type_match_guard(id_type(5), ["type(x) == int"])
408406
guard_manager.add_lambda_guard(
409407
functools.partial(ge_match, expected=5),
410-
{},
411-
False,
412408
ge_match_verbose_code_parts(expected=5),
413409
)
414410
guard_manager.add_lambda_guard(
415411
functools.partial(less_match, expected=10),
416-
{},
417-
False,
418412
less_match_verbose_code_parts(expected=10),
419413
)
420414
self.assertEqual(len(guard_manager.get_leaf_guards()), 3)
@@ -434,14 +428,10 @@ def __init__(self, x, y):
434428
guard_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"])
435429
guard_manager.getattr_manager("x", "x", 1, default_mgr_enum).add_lambda_guard(
436430
functools.partial(equals_match, expected=foo.x),
437-
{},
438-
False,
439431
equals_match_verbose_code_parts(foo.x),
440432
)
441433
guard_manager.getattr_manager("y", "y", 2, default_mgr_enum).add_lambda_guard(
442434
functools.partial(equals_match, expected=foo.y),
443-
{},
444-
False,
445435
equals_match_verbose_code_parts(foo.y),
446436
)
447437
self.assertEqual(len(guard_manager.get_leaf_guards()), 1)
@@ -484,14 +474,10 @@ def test_item_guard_manager(self):
484474
guard_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"])
485475
guard_manager.getitem_manager(0, "", 1, default_mgr_enum).add_lambda_guard(
486476
functools.partial(equals_match, expected=foo[0]),
487-
{},
488-
False,
489477
equals_match_verbose_code_parts(foo[0]),
490478
)
491479
guard_manager.getitem_manager(1, "", 2, default_mgr_enum).add_lambda_guard(
492480
functools.partial(equals_match, expected=foo[1]),
493-
{},
494-
False,
495481
equals_match_verbose_code_parts(foo[1]),
496482
)
497483
self.assertEqual(len(guard_manager.get_leaf_guards()), 1)
@@ -599,8 +585,6 @@ def test_globals(self):
599585
lambda x: isinstance(x, Pair)
600586
and isinstance(x.x, torch.Tensor)
601587
and isinstance(x.y, int),
602-
{},
603-
False,
604588
"global guard fail",
605589
)
606590

@@ -651,8 +635,6 @@ def mul(self, x):
651635
)
652636
attr_manager.add_lambda_guard(
653637
lambda x: x == 4,
654-
{},
655-
False,
656638
"Expected value 4",
657639
)
658640

@@ -693,8 +675,6 @@ def test_global_weakref(self):
693675

694676
weakref_manager.add_lambda_guard(
695677
lambda x: isinstance(x, torch.Tensor),
696-
{},
697-
False,
698678
"global weakref fail",
699679
)
700680

@@ -714,8 +694,6 @@ def test_lambda_manager(self):
714694
)
715695
foo_mgr.add_lambda_guard(
716696
lambda x: x == 3,
717-
{},
718-
False,
719697
"Expected value 3",
720698
)
721699
self.assertTrue(guard_manager.check(a))
@@ -801,7 +779,7 @@ def nothing():
801779
# Add key-value manager (nothing : {"z" : 3})
802780
self.assertTrue(root.check(f_locals))
803781
dict_mgr.get_key_manager(1, "", nothing, default_mgr_enum).add_lambda_guard(
804-
lambda x: x is nothing, {}, False, ["x is nothing"]
782+
lambda x: x is nothing, ["x is nothing"]
805783
)
806784
self.assertTrue(root.check(f_locals))
807785
value_mgr = dict_mgr.get_value_manager(

test/dynamo/test_misc.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7207,9 +7207,7 @@ def fn(x):
72077207
return x + 1
72087208

72097209
guard_manager = torch._dynamo.guards.RootGuardManager()
7210-
guard_manager.add_lambda_guard(
7211-
lambda L: isinstance(L["x"], int), {"x": 0}, True, []
7212-
)
7210+
guard_manager.add_lambda_guard(lambda L: isinstance(L["x"], int), [])
72137211

72147212
def injected(x):
72157213
return x + 42
@@ -7234,33 +7232,27 @@ def fn(x):
72347232
return x + 1
72357233

72367234
guard_manager_bool = torch._dynamo.guards.RootGuardManager()
7237-
guard_manager_bool.add_lambda_guard(
7238-
lambda L: isinstance(L["x"], bool), {"x": 0}, True, []
7239-
)
7235+
guard_manager_bool.add_lambda_guard(lambda L: isinstance(L["x"], bool), [])
72407236

72417237
def injected_bool(x: bool):
72427238
return x + 102
72437239

72447240
guard_manager_int = torch._dynamo.guards.RootGuardManager()
7245-
guard_manager_int.add_lambda_guard(
7246-
lambda L: isinstance(L["x"], int), {"x": 0}, True, []
7247-
)
7241+
guard_manager_int.add_lambda_guard(lambda L: isinstance(L["x"], int), [])
72487242

72497243
def injected_int(x: int):
72507244
return x + 42
72517245

72527246
guard_manager_tensor = torch._dynamo.guards.RootGuardManager()
72537247
guard_manager_tensor.add_lambda_guard(
7254-
lambda L: isinstance(L["x"], torch.Tensor), {"x": 0}, True, []
7248+
lambda L: isinstance(L["x"], torch.Tensor), []
72557249
)
72567250

72577251
def injected_tensor(x: torch.Tensor):
72587252
return x + 100
72597253

72607254
guard_manager_str = torch._dynamo.guards.RootGuardManager()
7261-
guard_manager_str.add_lambda_guard(
7262-
lambda L: isinstance(L["x"], str), {"x": 0}, True, []
7263-
)
7255+
guard_manager_str.add_lambda_guard(lambda L: isinstance(L["x"], str), [])
72647256

72657257
def injected_str(x: str):
72667258
return x + "1"
@@ -7337,10 +7329,7 @@ def fn(x):
73377329

73387330
guard_manager_bool = torch._dynamo.guards.RootGuardManager()
73397331
guard_manager_bool.add_lambda_guard(
7340-
lambda L: isinstance(L["x"], bool),
7341-
{"x": 0},
7342-
True,
7343-
["isinstance(L['x'], bool)"],
7332+
lambda L: isinstance(L["x"], bool), ["isinstance(L['x'], bool)"]
73447333
)
73457334

73467335
def injected_bool(x: bool):

torch/_C/_dynamo/guards.pyi

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,7 @@ class GuardManager:
222222
) -> GuardManager: ...
223223
# Leaf guards
224224
def add_lambda_guard(
225-
self,
226-
user_lambda: Callable[..., Any],
227-
required_locals: dict[str, int],
228-
construct_partial_framelocals_dict: bool,
229-
verbose_code_parts: list[str],
225+
self, user_lambda: Callable[..., Any], verbose_code_parts: list[str]
230226
) -> None: ...
231227
def add_lambda_guard_no_args(
232228
self, user_lambda: Callable[..., Any], verbose_code_parts: list[str]
@@ -359,8 +355,6 @@ class RootGuardManager(GuardManager):
359355
def add_epilogue_lambda_guard(
360356
self,
361357
guard: LeafGuard,
362-
required_locals: dict[str, int],
363-
construct_partial_framelocals_dict: bool,
364358
verbose_code_parts: list[str],
365359
) -> None: ...
366360
def clone_manager(

torch/_dynamo/config.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -381,12 +381,6 @@
381381
# useful for regional compilation.
382382
max_saved_pointers_for_recursive_dict_tags_check = 256
383383

384-
# Controls whether to construct the partial framelocals to dict for lambda
385-
# guards. This is a temporary flag to allow quick fallback behavior in case of
386-
# unexpected issues. Default is True, i.e., we will construct only partial
387-
# dict, a faster version for guards. Set to False to fallback to old behavior.
388-
construct_partial_framelocals_dict = True
389-
390384
# If True, raises exception if TorchDynamo is called with a context manager
391385
raise_on_ctx_manager_usage = True
392386

torch/_dynamo/guards.py

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -235,16 +235,14 @@
235235
)
236236

237237

238-
def get_framelocals_idx(code: types.CodeType, var_name: str) -> Optional[int]:
238+
def get_framelocals_idx(code: types.CodeType, var_name: str) -> int:
239239
# Refer to index in the frame's localsplus directly.
240240
# NOTE: name order for a code object doesn't change.
241241
# NOTE: we need to find the LAST matching index because <= 3.10 contains
242242
# duplicate names in the case of cells: a name can be both local and cell
243243
# and will take up 2 slots of the frame's localsplus. The correct behavior
244244
# is to refer to the cell, which has a higher index.
245245
framelocals_names_reversed = code_framelocals_names_reversed_cached(code)
246-
if var_name not in framelocals_names_reversed:
247-
return None
248246
framelocals_idx = (
249247
len(framelocals_names_reversed) - framelocals_names_reversed.index(var_name) - 1
250248
)
@@ -1362,7 +1360,6 @@ def get_guard_manager_from_source(self, source: Source) -> GuardManager:
13621360
# Use istype instead of isinstance to check for exact type of source.
13631361
if istype(source, LocalSource):
13641362
framelocals_idx = get_framelocals_idx(self.f_code, source.local_name)
1365-
assert framelocals_idx is not None
13661363
out = root_guard_manager.framelocals_manager(
13671364
key=(source.local_name, framelocals_idx),
13681365
source=source_name,
@@ -1758,34 +1755,15 @@ def add_python_lambda_leaf_guard_to_root(
17581755
guards_log.debug("Python shape guard function:\n%s", pycode)
17591756
exec(pycode, globals_for_guard_fn, out)
17601757
guard_fn = out["___make_guard_fn"](*closure_vars.values())
1761-
1762-
required_locals = {}
1763-
all_locals = self.scope["L"].keys()
1764-
for var_name in guard_fn.__code__.co_consts:
1765-
if isinstance(var_name, str) and var_name in all_locals:
1766-
index = get_framelocals_idx(self.f_code, var_name)
1767-
if index is not None:
1768-
required_locals[var_name] = index
1769-
1770-
construct_partial_framelocals_dict = config.construct_partial_framelocals_dict
1771-
17721758
if is_epilogue:
17731759
# Epilogue guards are run after all the other guards have finished.
17741760
# If epilogue guards contain a getattr or getitem access, one of the
17751761
# other guards would fail preventing the epilogue guards to run.
17761762
self.guard_manager.root.add_epilogue_lambda_guard(
1777-
guard_fn,
1778-
required_locals,
1779-
construct_partial_framelocals_dict,
1780-
verbose_code_parts,
1763+
guard_fn, verbose_code_parts
17811764
)
17821765
else:
1783-
self.guard_manager.root.add_lambda_guard(
1784-
guard_fn,
1785-
required_locals,
1786-
construct_partial_framelocals_dict,
1787-
verbose_code_parts,
1788-
)
1766+
self.guard_manager.root.add_lambda_guard(guard_fn, verbose_code_parts)
17891767

17901768
# Warning: use this with care! This lets you access what the current
17911769
# value of the value you are guarding on is. You probably don't want

torch/_dynamo/output_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2074,7 +2074,7 @@ def compile_and_call_fx_graph(
20742074
check_fn_source = inspect.getsource(specialization.check_fn).strip()
20752075
# Required because the LABDA_GUARD API requires a root guard manager
20762076
unused_root_guard_manager = RootGuardManager()
2077-
check_fn = guards.LAMBDA_GUARD_NO_FRAMELOCALS( # type: ignore[attr-defined]
2077+
check_fn = guards.LAMBDA_GUARD( # type: ignore[attr-defined]
20782078
unused_root_guard_manager,
20792079
specialization.check_fn,
20802080
[check_fn_source],

torch/csrc/dynamo/guards.cpp

Lines changed: 7 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1625,7 +1625,9 @@ class LeafGuard {
16251625
// is not exposed to Python and can only be called from C++.
16261626
virtual bool check_nopybind(PyObject* value) = 0;
16271627
virtual bool check_nopybind(FrameLocalsMapping* map) {
1628-
throw std::runtime_error("fallback to python");
1628+
// throw std::runtime_error("fallback to python");
1629+
// Could fallback to running check on the Python dict (lazily constructed)
1630+
return check_nopybind((PyObject*)map->to_dict());
16291631
}
16301632

16311633
virtual ~LeafGuard() = default;
@@ -1656,13 +1658,8 @@ class LAMBDA_GUARD : public LeafGuard {
16561658
LAMBDA_GUARD(
16571659
RootGuardManager* root_guard_manager,
16581660
py::object guard_check_fn,
1659-
py::object required_locals,
1660-
bool construct_partial_framelocals_dict,
16611661
py::object verbose_code_parts)
1662-
: LeafGuard(root_guard_manager, std::move(verbose_code_parts)),
1663-
_required_locals(py::cast<py::dict>(required_locals)),
1664-
_construct_partial_framelocals_dict(
1665-
construct_partial_framelocals_dict) {
1662+
: LeafGuard(root_guard_manager, std::move(verbose_code_parts)) {
16661663
if (py::isinstance<py::function>(guard_check_fn)) {
16671664
_guard_check_fn = py::cast<py::function>(std::move(guard_check_fn));
16681665
} else {
@@ -1699,30 +1696,7 @@ class LAMBDA_GUARD : public LeafGuard {
16991696
return GuardDebugInfo(false, verbose_code_parts(), 0);
17001697
}
17011698

1702-
bool check_nopybind(FrameLocalsMapping* map) override {
1703-
// TODO (anijain2305) - Get rid of the _construct_partial_framelocals_dict
1704-
// once its stable.
1705-
if (_construct_partial_framelocals_dict) {
1706-
py::dict partial_dict;
1707-
1708-
for (auto item : _required_locals) {
1709-
partial_dict[item.first] = map->get(item.second.cast<int>());
1710-
}
1711-
1712-
return check_nopybind(partial_dict.ptr());
1713-
}
1714-
return check_nopybind((PyObject*)map->to_dict());
1715-
}
1716-
17171699
private:
1718-
// Dict of (local_name, framelocal_idx) representing the minimum number of
1719-
// framelocals needed to construct the dictionary for the lambda guard.
1720-
py::dict _required_locals;
1721-
1722-
// Temporary flag to allow a fallback behavior. With stability, we can remove
1723-
// this member.
1724-
bool _construct_partial_framelocals_dict;
1725-
17261700
// The user provided lambda function for check_fn.
17271701
py::function _guard_check_fn;
17281702
};
@@ -1798,12 +1772,7 @@ class LAMBDA_GUARD_NO_FRAMELOCALS : public LAMBDA_GUARD {
17981772
RootGuardManager* root_guard_manager,
17991773
py::object guard_check_fn,
18001774
py::object verbose_code_parts)
1801-
: LAMBDA_GUARD(
1802-
root_guard_manager,
1803-
guard_check_fn,
1804-
py::dict(),
1805-
false,
1806-
verbose_code_parts) {}
1775+
: LAMBDA_GUARD(root_guard_manager, guard_check_fn, verbose_code_parts) {}
18071776

18081777
bool check_nopybind(PyObject* value) override { // borrowed ref
18091778
return LAMBDA_GUARD::check_nopybind(value);
@@ -6802,8 +6771,7 @@ PyObject* torch_c_dynamo_guards_init() {
68026771
.def("verbose_code_parts", &LeafGuard::verbose_code_parts);
68036772
py::class_<LAMBDA_GUARD, LeafGuard, std::shared_ptr<LAMBDA_GUARD>>(
68046773
py_m, "LAMBDA_GUARD")
6805-
.def(
6806-
py::init<RootGuardManager*, py::function, py::dict, bool, py::list>())
6774+
.def(py::init<RootGuardManager*, py::function, py::list>())
68076775
.def("__call__", &LAMBDA_GUARD::check);
68086776
py::class_<
68096777
LAMBDA_GUARD_NO_ARGS,
@@ -7126,14 +7094,10 @@ PyObject* torch_c_dynamo_guards_init() {
71267094
"add_lambda_guard",
71277095
[](GuardManager& self,
71287096
py::object lambda,
7129-
py::object required_locals,
7130-
bool construct_partial_framelocals_dict,
71317097
py::object verbose_code_parts) -> void {
71327098
self.add_leaf_guard(std::make_shared<LAMBDA_GUARD>(
71337099
self.get_root(),
71347100
std::move(lambda),
7135-
std::move(required_locals),
7136-
construct_partial_framelocals_dict,
71377101
std::move(verbose_code_parts)));
71387102
})
71397103
.def(
@@ -7816,15 +7780,9 @@ PyObject* torch_c_dynamo_guards_init() {
78167780
"add_epilogue_lambda_guard",
78177781
[](RootGuardManager& self,
78187782
py::object lambda,
7819-
py::object required_locals,
7820-
bool construct_partial_framelocals_dict,
78217783
py::object verbose_code_parts) -> void {
78227784
self.add_epilogue_lambda_guard(std::make_unique<LAMBDA_GUARD>(
7823-
&self,
7824-
std::move(lambda),
7825-
std::move(required_locals),
7826-
construct_partial_framelocals_dict,
7827-
std::move(verbose_code_parts)));
7785+
&self, std::move(lambda), std::move(verbose_code_parts)));
78287786
});
78297787

78307788
// Dict Guard Manager

0 commit comments

Comments
 (0)