@@ -1625,7 +1625,9 @@ class LeafGuard {
1625
1625
// is not exposed to Python and can only be called from C++.
1626
1626
virtual bool check_nopybind (PyObject* value) = 0;
1627
1627
virtual bool check_nopybind (FrameLocalsMapping* map) {
1628
- throw std::runtime_error (" fallback to python" );
1628
+ // throw std::runtime_error("fallback to python");
1629
+ // Could fallback to running check on the Python dict (lazily constructed)
1630
+ return check_nopybind ((PyObject*)map->to_dict ());
1629
1631
}
1630
1632
1631
1633
virtual ~LeafGuard () = default ;
@@ -1656,13 +1658,8 @@ class LAMBDA_GUARD : public LeafGuard {
1656
1658
LAMBDA_GUARD (
1657
1659
RootGuardManager* root_guard_manager,
1658
1660
py::object guard_check_fn,
1659
- py::object required_locals,
1660
- bool construct_partial_framelocals_dict,
1661
1661
py::object verbose_code_parts)
1662
- : LeafGuard(root_guard_manager, std::move(verbose_code_parts)),
1663
- _required_locals (py::cast<py::dict>(required_locals)),
1664
- _construct_partial_framelocals_dict(
1665
- construct_partial_framelocals_dict) {
1662
+ : LeafGuard(root_guard_manager, std::move(verbose_code_parts)) {
1666
1663
if (py::isinstance<py::function>(guard_check_fn)) {
1667
1664
_guard_check_fn = py::cast<py::function>(std::move (guard_check_fn));
1668
1665
} else {
@@ -1699,30 +1696,7 @@ class LAMBDA_GUARD : public LeafGuard {
1699
1696
return GuardDebugInfo (false , verbose_code_parts (), 0 );
1700
1697
}
1701
1698
1702
- bool check_nopybind (FrameLocalsMapping* map) override {
1703
- // TODO (anijain2305) - Get rid of the _construct_partial_framelocals_dict
1704
- // once its stable.
1705
- if (_construct_partial_framelocals_dict) {
1706
- py::dict partial_dict;
1707
-
1708
- for (auto item : _required_locals) {
1709
- partial_dict[item.first ] = map->get (item.second .cast <int >());
1710
- }
1711
-
1712
- return check_nopybind (partial_dict.ptr ());
1713
- }
1714
- return check_nopybind ((PyObject*)map->to_dict ());
1715
- }
1716
-
1717
1699
private:
1718
- // Dict of (local_name, framelocal_idx) representing the minimum number of
1719
- // framelocals needed to construct the dictionary for the lambda guard.
1720
- py::dict _required_locals;
1721
-
1722
- // Temporary flag to allow a fallback behavior. With stability, we can remove
1723
- // this member.
1724
- bool _construct_partial_framelocals_dict;
1725
-
1726
1700
// The user provided lambda function for check_fn.
1727
1701
py::function _guard_check_fn;
1728
1702
};
@@ -1798,12 +1772,7 @@ class LAMBDA_GUARD_NO_FRAMELOCALS : public LAMBDA_GUARD {
1798
1772
RootGuardManager* root_guard_manager,
1799
1773
py::object guard_check_fn,
1800
1774
py::object verbose_code_parts)
1801
- : LAMBDA_GUARD(
1802
- root_guard_manager,
1803
- guard_check_fn,
1804
- py::dict (),
1805
- false,
1806
- verbose_code_parts) {}
1775
+ : LAMBDA_GUARD(root_guard_manager, guard_check_fn, verbose_code_parts) {}
1807
1776
1808
1777
bool check_nopybind (PyObject* value) override { // borrowed ref
1809
1778
return LAMBDA_GUARD::check_nopybind (value);
@@ -6802,8 +6771,7 @@ PyObject* torch_c_dynamo_guards_init() {
6802
6771
.def (" verbose_code_parts" , &LeafGuard::verbose_code_parts);
6803
6772
py::class_<LAMBDA_GUARD, LeafGuard, std::shared_ptr<LAMBDA_GUARD>>(
6804
6773
py_m, " LAMBDA_GUARD" )
6805
- .def (
6806
- py::init<RootGuardManager*, py::function, py::dict, bool , py::list>())
6774
+ .def (py::init<RootGuardManager*, py::function, py::list>())
6807
6775
.def (" __call__" , &LAMBDA_GUARD::check);
6808
6776
py::class_<
6809
6777
LAMBDA_GUARD_NO_ARGS,
@@ -7126,14 +7094,10 @@ PyObject* torch_c_dynamo_guards_init() {
7126
7094
" add_lambda_guard" ,
7127
7095
[](GuardManager& self,
7128
7096
py::object lambda,
7129
- py::object required_locals,
7130
- bool construct_partial_framelocals_dict,
7131
7097
py::object verbose_code_parts) -> void {
7132
7098
self.add_leaf_guard (std::make_shared<LAMBDA_GUARD>(
7133
7099
self.get_root (),
7134
7100
std::move (lambda),
7135
- std::move (required_locals),
7136
- construct_partial_framelocals_dict,
7137
7101
std::move (verbose_code_parts)));
7138
7102
})
7139
7103
.def (
@@ -7816,15 +7780,9 @@ PyObject* torch_c_dynamo_guards_init() {
7816
7780
" add_epilogue_lambda_guard" ,
7817
7781
[](RootGuardManager& self,
7818
7782
py::object lambda,
7819
- py::object required_locals,
7820
- bool construct_partial_framelocals_dict,
7821
7783
py::object verbose_code_parts) -> void {
7822
7784
self.add_epilogue_lambda_guard (std::make_unique<LAMBDA_GUARD>(
7823
- &self,
7824
- std::move (lambda),
7825
- std::move (required_locals),
7826
- construct_partial_framelocals_dict,
7827
- std::move (verbose_code_parts)));
7785
+ &self, std::move (lambda), std::move (verbose_code_parts)));
7828
7786
});
7829
7787
7830
7788
// Dict Guard Manager
0 commit comments