Skip to content

Commit 66ad881

Browse files
anijain2305pytorchmergebot
authored andcommitted
[dynamo][guards][refactor] Simplify type extraction from GuardManager (pytorch#159752)
Pull Request resolved: pytorch#159752 Approved by: https://github.com/jansel
1 parent 1d3eef2 commit 66ad881

File tree

3 files changed

+39
-88
lines changed

3 files changed

+39
-88
lines changed

test/dynamo/test_guard_manager.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -931,7 +931,7 @@ def hook(guard_wrapper, f_locals, builder):
931931

932932
# Check types of foo.x
933933
foo_x_mgr = builder.get_guard_manager_from_source(foo_x_source)
934-
self.assertTrue(foo_x_mgr.is_guarded_value_dict())
934+
self.assertTrue(issubclass(foo_x_mgr.get_type_of_guarded_value(), dict))
935935

936936
# Check types of foo.x["a"]
937937
foo_x_a_source = DictGetItemSource(foo_x_source, "a")
@@ -946,12 +946,14 @@ def hook(guard_wrapper, f_locals, builder):
946946
# Check types of foo.z
947947
foo_z_source = AttrSource(foo_source, "z")
948948
foo_z_mgr = builder.get_guard_manager_from_source(foo_z_source)
949-
self.assertTrue(foo_z_mgr.is_guarded_value_empty_dict())
949+
self.assertTrue(issubclass(foo_z_mgr.get_type_of_guarded_value(), dict))
950950

951951
# Check types of mod
952952
mod_source = LocalSource("mod")
953953
mod_mgr = builder.get_guard_manager_from_source(mod_source)
954-
self.assertTrue(mod_mgr.is_guarded_value_nn_module())
954+
self.assertTrue(
955+
issubclass(mod_mgr.get_type_of_guarded_value(), torch.nn.Module)
956+
)
955957

956958
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
957959
with install_guard_manager_testing_hook(hook):

torch/_dynamo/guards.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def find_tag_safe_roots(self):
355355
def visit_dict_manager(node):
356356
# Just recurse through the key and value dict managers and check if
357357
# all of them are tag safe nodes.
358-
assert node.is_guarded_value_dict()
358+
assert issubclass(node.get_type_of_guarded_value(), dict)
359359

360360
tag_safe_roots = []
361361
is_subtree_tag_safe = True
@@ -394,12 +394,12 @@ def visit_manager(node):
394394
# If the node guards a tensor, mark it tag safe only if there
395395
# are no accessors. Presence of accessors means presence of
396396
# symbolic shape guards.
397-
if node.is_guarded_value_tensor():
397+
if issubclass(node.get_type_of_guarded_value(), torch.Tensor):
398398
if node.has_no_accessors() and not node.has_object_aliasing_guard():
399399
node.mark_tag_safe()
400400
else:
401401
node.mark_tag_safe()
402-
elif node.is_guarded_value_dict():
402+
elif issubclass(node.get_type_of_guarded_value(), dict):
403403
accessors = node.get_accessors()
404404
child_mgrs = node.get_child_managers()
405405
is_subtree_tag_safe = all(
@@ -408,7 +408,7 @@ def visit_manager(node):
408408
)
409409
if is_subtree_tag_safe:
410410
node.mark_tag_safe()
411-
elif node.is_guarded_value_nn_module():
411+
elif issubclass(node.get_type_of_guarded_value(), torch.nn.Module):
412412
accessors = node.get_accessors()
413413
child_mgrs = node.get_child_managers()
414414
is_subtree_tag_safe = all(
@@ -434,7 +434,7 @@ def visit(node):
434434

435435
tag_safe_roots = visit(self.root)
436436
for node in tag_safe_roots:
437-
if node.is_guarded_value_nn_module():
437+
if issubclass(node.get_type_of_guarded_value(), torch.nn.Module):
438438
node.mark_tag_safe_root()
439439

440440
def populate_diff_guard_manager(self):
@@ -468,7 +468,7 @@ def get_manager_line(self, guard_manager, accessor_str=None):
468468
s = t + ": source=" + source
469469
if accessor_str:
470470
s += ", " + accessor_str
471-
s += f", type={guard_manager.type_of_guarded_value()}"
471+
s += f", type={guard_manager.get_type_of_guarded_value()}"
472472
s += f", tag_safe=({guard_manager.is_tag_safe()}, {guard_manager.is_tag_safe_root()})"
473473
return s
474474

torch/csrc/dynamo/guards.cpp

Lines changed: 28 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,22 +1154,6 @@ std::string get_exception_message() {
11541154
return exc_message;
11551155
}
11561156

1157-
bool is_nn_module(py::handle example_value) {
1158-
py::object torch_module_cls = py::module_::import("torch.nn").attr("Module");
1159-
return py::isinstance(example_value, torch_module_cls);
1160-
}
1161-
1162-
std::string get_type_str(py::handle example_value) {
1163-
std::string type_name;
1164-
try {
1165-
type_name = py::str(py::type::of(example_value)).cast<std::string>();
1166-
} catch (const py::error_already_set& e) {
1167-
// Fallback that never throws in release builds
1168-
type_name = "<unprintable-type>";
1169-
}
1170-
return type_name;
1171-
}
1172-
11731157
bool is_immutable_object(py::handle example_value) {
11741158
py::object config_module = py::module_::import("torch._dynamo.config");
11751159

@@ -2611,15 +2595,13 @@ class GuardManager {
26112595
: _root(root),
26122596
_source(std::move(source)),
26132597
_is_dict(py::isinstance<py::dict>(example_value)),
2614-
_is_immutable(is_immutable_object(example_value)),
2615-
_is_nn_module(is_nn_module(example_value)),
2616-
_is_tensor(THPVariable_Check(example_value.ptr())),
2617-
_type_str(get_type_str(example_value)) {
2598+
_is_immutable(is_immutable_object(example_value)) {
26182599
if (_is_dict) {
26192600
_dict_tag = get_dict_version_unchecked(example_value.ptr());
2620-
_is_empty_dict = PyDict_Size(example_value.ptr()) == 0;
26212601
}
2622-
2602+
py::object typ = py::type::of(example_value);
2603+
py::object weakref_mod = py::module_::import("weakref");
2604+
_weak_type = weakref_mod.attr("ref")(typ);
26232605
py::object config_module = py::module_::import("torch._dynamo.config");
26242606
_max_saved_pointers_for_recursive_dict_tags_check =
26252607
config_module.attr("max_saved_pointers_for_recursive_dict_tags_check")
@@ -2681,28 +2663,19 @@ class GuardManager {
26812663
return _is_immutable;
26822664
}
26832665

2684-
bool is_guarded_value_nn_module() {
2685-
return _is_nn_module;
2686-
}
2687-
2688-
bool is_guarded_value_dict() {
2689-
return _is_dict;
2690-
}
2691-
2692-
bool is_guarded_value_empty_dict() {
2693-
return _is_empty_dict;
2694-
}
2695-
2696-
bool is_guarded_value_tensor() {
2697-
return _is_tensor;
2666+
bool is_recursive_dict_tag_matching_disabled() {
2667+
return _disable_dict_tag_matching;
26982668
}
26992669

2700-
std::string type_of_guarded_value() {
2701-
return _type_str;
2702-
}
2670+
py::object get_type_of_guarded_value() {
2671+
if (!_weak_type || _weak_type.is_none()) {
2672+
return py::type::of(py::none());
2673+
}
27032674

2704-
bool is_recursive_dict_tag_matching_disabled() {
2705-
return _disable_dict_tag_matching;
2675+
if (!PyCallable_Check(_weak_type.ptr())) {
2676+
throw std::runtime_error("_weak_type is not callable");
2677+
}
2678+
return _weak_type();
27062679
}
27072680

27082681
public:
@@ -2748,19 +2721,13 @@ class GuardManager {
27482721
RootGuardManager* root,
27492722
std::string source,
27502723
bool is_dict,
2751-
bool is_empty_dict,
27522724
bool is_immutable,
2753-
bool is_nn_module,
2754-
bool is_tensor,
2755-
std::string type_str)
2725+
py::object weak_type)
27562726
: _root(root),
27572727
_source(std::move(source)),
27582728
_is_dict(is_dict),
2759-
_is_empty_dict(is_empty_dict),
27602729
_is_immutable(is_immutable),
2761-
_is_nn_module(is_nn_module),
2762-
_is_tensor(is_tensor),
2763-
_type_str(std::move(type_str)) {}
2730+
_weak_type(weak_type) {}
27642731

27652732
void clone_common(
27662733
RootGuardManager* cloned_root,
@@ -2792,14 +2759,7 @@ class GuardManager {
27922759
return nullptr;
27932760
}
27942761
GuardManager* cloned_mgr = new GuardManager(
2795-
cloned_root,
2796-
_source,
2797-
_is_dict,
2798-
_is_empty_dict,
2799-
_is_immutable,
2800-
_is_nn_module,
2801-
_is_tensor,
2802-
_type_str);
2762+
cloned_root, _source, _is_dict, _is_immutable, _weak_type);
28032763
if (is_tag_safe()) {
28042764
cloned_mgr->mark_tag_safe();
28052765
if (is_tag_safe_root()) {
@@ -2975,7 +2935,7 @@ class GuardManager {
29752935
// This is a tag safe node, record the dict pointer
29762936
if (_is_dict) {
29772937
record_dict_pointer(_root, value);
2978-
} else if (_is_tensor && _has_no_tensor_aliasing_guard) {
2938+
} else if (_has_no_tensor_aliasing_guard) {
29792939
record_tensor_pointer(_root, value);
29802940
}
29812941
}
@@ -3285,11 +3245,7 @@ class GuardManager {
32853245
bool _has_no_tensor_aliasing_guard = false;
32863246

32873247
bool _is_dict = false;
3288-
bool _is_empty_dict = false;
32893248
bool _is_immutable = false;
3290-
bool _is_nn_module = false;
3291-
bool _is_tensor = false;
3292-
std::string _type_str;
32933249
uint64_t _dict_tag{0};
32943250
uint64_t _max_saved_pointers_for_recursive_dict_tags_check = 0;
32953251

@@ -3301,6 +3257,11 @@ class GuardManager {
33013257
_dict_pointers;
33023258
std::unordered_map<PyObject*, std::vector<PyObject*>> _tensor_pointers;
33033259
std::vector<WeakEntry> _tag_safe_entries;
3260+
3261+
protected:
3262+
// weakref to the type of guarded value
3263+
// protected because it is used for cloning by DictGuardManager
3264+
py::object _weak_type;
33043265
};
33053266

33063267
GuardAccessor::GuardAccessor(
@@ -3873,17 +3834,13 @@ class DictGuardManager : public GuardManager {
38733834
PyTypeObject* expected_type,
38743835
bool is_exact_dict_type,
38753836
std::vector<Py_ssize_t> indices,
3876-
std::string type_of,
3877-
bool is_empty_dict)
3837+
py::object weak_type)
38783838
: GuardManager(
38793839
cloned_root,
38803840
std::move(source),
38813841
true, // _is_dict
3882-
is_empty_dict,
38833842
false, // _is_immutable
3884-
false, // _is_nn_module
3885-
false, // _is_tensor
3886-
std::move(type_of)),
3843+
weak_type),
38873844
_size(size),
38883845
_expected_type(expected_type),
38893846
_is_exact_dict_type(is_exact_dict_type),
@@ -3903,8 +3860,7 @@ class DictGuardManager : public GuardManager {
39033860
_expected_type,
39043861
_is_exact_dict_type,
39053862
_indices,
3906-
type_of_guarded_value(),
3907-
is_guarded_value_empty_dict());
3863+
_weak_type);
39083864
if (is_tag_safe()) {
39093865
cloned_mgr->mark_tag_safe();
39103866
if (is_tag_safe_root()) {
@@ -6752,23 +6708,16 @@ PyObject* torch_c_dynamo_guards_init() {
67526708
.def(
67536709
"is_guarded_value_immutable",
67546710
&GuardManager::is_guarded_value_immutable)
6755-
.def(
6756-
"is_guarded_value_nn_module",
6757-
&GuardManager::is_guarded_value_nn_module)
6758-
.def("is_guarded_value_dict", &GuardManager::is_guarded_value_dict)
6759-
.def(
6760-
"is_guarded_value_empty_dict",
6761-
&GuardManager::is_guarded_value_empty_dict)
6762-
.def("is_guarded_value_tensor", &GuardManager::is_guarded_value_tensor)
67636711
.def("has_no_accessors", &GuardManager::has_no_accessors)
67646712
.def("mark_tag_safe", &GuardManager::mark_tag_safe)
67656713
.def("mark_tag_safe_root", &GuardManager::mark_tag_safe_root)
67666714
.def("is_tag_safe", &GuardManager::is_tag_safe)
67676715
.def("is_tag_safe_root", &GuardManager::is_tag_safe_root)
6768-
.def("type_of_guarded_value", &GuardManager::type_of_guarded_value)
67696716
.def(
67706717
"is_recursive_dict_tag_matching_disabled",
67716718
&GuardManager::is_recursive_dict_tag_matching_disabled)
6719+
.def(
6720+
"get_type_of_guarded_value", &GuardManager::get_type_of_guarded_value)
67726721
.def(
67736722
"get_accessors",
67746723
&GuardManager::get_accessors,

0 commit comments

Comments
 (0)