@@ -1154,22 +1154,6 @@ std::string get_exception_message() {
11541154 return exc_message;
11551155}
11561156
1157- bool is_nn_module (py::handle example_value) {
1158- py::object torch_module_cls = py::module_::import (" torch.nn" ).attr (" Module" );
1159- return py::isinstance (example_value, torch_module_cls);
1160- }
1161-
1162- std::string get_type_str (py::handle example_value) {
1163- std::string type_name;
1164- try {
1165- type_name = py::str (py::type::of (example_value)).cast <std::string>();
1166- } catch (const py::error_already_set& e) {
1167- // Fallback that never throws in release builds
1168- type_name = " <unprintable-type>" ;
1169- }
1170- return type_name;
1171- }
1172-
11731157bool is_immutable_object (py::handle example_value) {
11741158 py::object config_module = py::module_::import (" torch._dynamo.config" );
11751159
@@ -2611,15 +2595,13 @@ class GuardManager {
26112595 : _root(root),
26122596 _source (std::move(source)),
26132597 _is_dict(py::isinstance<py::dict>(example_value)),
2614- _is_immutable(is_immutable_object(example_value)),
2615- _is_nn_module(is_nn_module(example_value)),
2616- _is_tensor(THPVariable_Check(example_value.ptr())),
2617- _type_str(get_type_str(example_value)) {
2598+ _is_immutable(is_immutable_object(example_value)) {
26182599 if (_is_dict) {
26192600 _dict_tag = get_dict_version_unchecked (example_value.ptr ());
2620- _is_empty_dict = PyDict_Size (example_value.ptr ()) == 0 ;
26212601 }
2622-
2602+ py::object typ = py::type::of (example_value);
2603+ py::object weakref_mod = py::module_::import (" weakref" );
2604+ _weak_type = weakref_mod.attr (" ref" )(typ);
26232605 py::object config_module = py::module_::import (" torch._dynamo.config" );
26242606 _max_saved_pointers_for_recursive_dict_tags_check =
26252607 config_module.attr (" max_saved_pointers_for_recursive_dict_tags_check" )
@@ -2681,28 +2663,19 @@ class GuardManager {
26812663 return _is_immutable;
26822664 }
26832665
2684- bool is_guarded_value_nn_module () {
2685- return _is_nn_module;
2686- }
2687-
2688- bool is_guarded_value_dict () {
2689- return _is_dict;
2690- }
2691-
2692- bool is_guarded_value_empty_dict () {
2693- return _is_empty_dict;
2694- }
2695-
2696- bool is_guarded_value_tensor () {
2697- return _is_tensor;
2666+ bool is_recursive_dict_tag_matching_disabled () {
2667+ return _disable_dict_tag_matching;
26982668 }
26992669
2700- std::string type_of_guarded_value () {
2701- return _type_str;
2702- }
2670+ py::object get_type_of_guarded_value () {
2671+ if (!_weak_type || _weak_type.is_none ()) {
2672+ return py::type::of (py::none ());
2673+ }
27032674
2704- bool is_recursive_dict_tag_matching_disabled () {
2705- return _disable_dict_tag_matching;
2675+ if (!PyCallable_Check (_weak_type.ptr ())) {
2676+ throw std::runtime_error (" _weak_type is not callable" );
2677+ }
2678+ return _weak_type ();
27062679 }
27072680
27082681 public:
@@ -2748,19 +2721,13 @@ class GuardManager {
27482721 RootGuardManager* root,
27492722 std::string source,
27502723 bool is_dict,
2751- bool is_empty_dict,
27522724 bool is_immutable,
2753- bool is_nn_module,
2754- bool is_tensor,
2755- std::string type_str)
2725+ py::object weak_type)
27562726 : _root(root),
27572727 _source (std::move(source)),
27582728 _is_dict(is_dict),
2759- _is_empty_dict(is_empty_dict),
27602729 _is_immutable(is_immutable),
2761- _is_nn_module(is_nn_module),
2762- _is_tensor(is_tensor),
2763- _type_str(std::move(type_str)) {}
2730+ _weak_type(weak_type) {}
27642731
27652732 void clone_common (
27662733 RootGuardManager* cloned_root,
@@ -2792,14 +2759,7 @@ class GuardManager {
27922759 return nullptr ;
27932760 }
27942761 GuardManager* cloned_mgr = new GuardManager (
2795- cloned_root,
2796- _source,
2797- _is_dict,
2798- _is_empty_dict,
2799- _is_immutable,
2800- _is_nn_module,
2801- _is_tensor,
2802- _type_str);
2762+ cloned_root, _source, _is_dict, _is_immutable, _weak_type);
28032763 if (is_tag_safe ()) {
28042764 cloned_mgr->mark_tag_safe ();
28052765 if (is_tag_safe_root ()) {
@@ -2975,7 +2935,7 @@ class GuardManager {
29752935 // This is a tag safe node, record the dict pointer
29762936 if (_is_dict) {
29772937 record_dict_pointer (_root, value);
2978- } else if (_is_tensor && _has_no_tensor_aliasing_guard) {
2938+ } else if (_has_no_tensor_aliasing_guard) {
29792939 record_tensor_pointer (_root, value);
29802940 }
29812941 }
@@ -3285,11 +3245,7 @@ class GuardManager {
32853245 bool _has_no_tensor_aliasing_guard = false ;
32863246
32873247 bool _is_dict = false ;
3288- bool _is_empty_dict = false ;
32893248 bool _is_immutable = false ;
3290- bool _is_nn_module = false ;
3291- bool _is_tensor = false ;
3292- std::string _type_str;
32933249 uint64_t _dict_tag{0 };
32943250 uint64_t _max_saved_pointers_for_recursive_dict_tags_check = 0 ;
32953251
@@ -3301,6 +3257,11 @@ class GuardManager {
33013257 _dict_pointers;
33023258 std::unordered_map<PyObject*, std::vector<PyObject*>> _tensor_pointers;
33033259 std::vector<WeakEntry> _tag_safe_entries;
3260+
3261+ protected:
3262+ // weakref to the type of guarded value
3263+ // protected because it is used for cloning by DictGuardManager
3264+ py::object _weak_type;
33043265};
33053266
33063267GuardAccessor::GuardAccessor (
@@ -3873,17 +3834,13 @@ class DictGuardManager : public GuardManager {
38733834 PyTypeObject* expected_type,
38743835 bool is_exact_dict_type,
38753836 std::vector<Py_ssize_t> indices,
3876- std::string type_of,
3877- bool is_empty_dict)
3837+ py::object weak_type)
38783838 : GuardManager(
38793839 cloned_root,
38803840 std::move (source),
38813841 true, // _is_dict
3882- is_empty_dict,
38833842 false, // _is_immutable
3884- false, // _is_nn_module
3885- false, // _is_tensor
3886- std::move(type_of)),
3843+ weak_type),
38873844 _size(size),
38883845 _expected_type(expected_type),
38893846 _is_exact_dict_type(is_exact_dict_type),
@@ -3903,8 +3860,7 @@ class DictGuardManager : public GuardManager {
39033860 _expected_type,
39043861 _is_exact_dict_type,
39053862 _indices,
3906- type_of_guarded_value (),
3907- is_guarded_value_empty_dict ());
3863+ _weak_type);
39083864 if (is_tag_safe ()) {
39093865 cloned_mgr->mark_tag_safe ();
39103866 if (is_tag_safe_root ()) {
@@ -6752,23 +6708,16 @@ PyObject* torch_c_dynamo_guards_init() {
67526708 .def (
67536709 " is_guarded_value_immutable" ,
67546710 &GuardManager::is_guarded_value_immutable)
6755- .def (
6756- " is_guarded_value_nn_module" ,
6757- &GuardManager::is_guarded_value_nn_module)
6758- .def (" is_guarded_value_dict" , &GuardManager::is_guarded_value_dict)
6759- .def (
6760- " is_guarded_value_empty_dict" ,
6761- &GuardManager::is_guarded_value_empty_dict)
6762- .def (" is_guarded_value_tensor" , &GuardManager::is_guarded_value_tensor)
67636711 .def (" has_no_accessors" , &GuardManager::has_no_accessors)
67646712 .def (" mark_tag_safe" , &GuardManager::mark_tag_safe)
67656713 .def (" mark_tag_safe_root" , &GuardManager::mark_tag_safe_root)
67666714 .def (" is_tag_safe" , &GuardManager::is_tag_safe)
67676715 .def (" is_tag_safe_root" , &GuardManager::is_tag_safe_root)
6768- .def (" type_of_guarded_value" , &GuardManager::type_of_guarded_value)
67696716 .def (
67706717 " is_recursive_dict_tag_matching_disabled" ,
67716718 &GuardManager::is_recursive_dict_tag_matching_disabled)
6719+ .def (
6720+ " get_type_of_guarded_value" , &GuardManager::get_type_of_guarded_value)
67726721 .def (
67736722 " get_accessors" ,
67746723 &GuardManager::get_accessors,
0 commit comments