Skip to content

Commit 53e47af

Browse files
anijain2305pytorchmergebot
authored andcommitted
[dynamo][guards] Read the attr name from GetAttrGuardAccessor (pytorch#159754)
Pull Request resolved: pytorch#159754 Approved by: https://github.com/jansel ghstack dependencies: pytorch#159752
1 parent 66ad881 commit 53e47af

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

test/dynamo/test_guard_manager.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,12 @@ def hook(guard_wrapper, f_locals, builder):
10081008
from torch._dynamo.source import AttrSource, LocalSource
10091009

10101010
foo_source = LocalSource("foo")
1011+
foo_mgr = builder.get_guard_manager_from_source(foo_source)
1012+
for accessor in foo_mgr.get_accessors():
1013+
if isinstance(accessor, GetAttrGuardAccessor):
1014+
self.assertTrue(
1015+
accessor.get_attr_name() in ("a", "b", "c", "d", "e")
1016+
)
10111017

10121018
# Check types of foo.a
10131019
foo_a_source = AttrSource(foo_source, "a")

torch/_C/_dynamo/guards.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ class GetGenericDictGuardAccessor(GuardAccessor): ...
142142
class TypeDictGuardAccessor(GuardAccessor): ...
143143
class TypeMROGuardAccessor(GuardAccessor): ...
144144

145+
class GetAttrGuardAccessor(GuardAccessor):
146+
def get_attr_name(self) -> str: ...
147+
145148
def install_object_aliasing_guard(
146149
guard_managers: list[GuardManager],
147150
tensor_names: list[str],

torch/csrc/dynamo/guards.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4244,6 +4244,10 @@ class GetAttrGuardAccessor : public GuardAccessor {
42444244
")";
42454245
}
42464246

4247+
std::string get_attr_name() {
4248+
return py::str(_attr_name).cast<std::string>();
4249+
}
4250+
42474251
public: // cloning functions
42484252
GetAttrGuardAccessor(GuardManager* guard_manager, GetAttrGuardAccessor* from)
42494253
: GuardAccessor(guard_manager, from) {
@@ -6584,11 +6588,11 @@ PyObject* torch_c_dynamo_guards_init() {
65846588
py::class_<GuardAccessor, std::unique_ptr<GuardAccessor>>(
65856589
py_m, "GuardAccessor")
65866590
.def("repr", &GuardAccessor::repr);
6587-
// NOLINTNEXTLINE(bugprone-unused-raii)
65886591
py::class_<
65896592
GetAttrGuardAccessor,
65906593
GuardAccessor,
6591-
std::unique_ptr<GetAttrGuardAccessor>>(py_m, "GetAttrGuardAccessor");
6594+
std::unique_ptr<GetAttrGuardAccessor>>(py_m, "GetAttrGuardAccessor")
6595+
.def("get_attr_name", &GetAttrGuardAccessor::get_attr_name);
65926596
// NOLINTNEXTLINE(bugprone-unused-raii)
65936597
py::class_<
65946598
GenericGetAttrGuardAccessor,

0 commit comments

Comments
 (0)