@@ -98,17 +98,18 @@ class PyTreeTypeRegistry {
9898
9999 using RegistrationPtr = std::shared_ptr<const Registration>;
100100
101- // Register a new custom type. Objects of `cls` will be treated as container node types in
101+ // Register a new custom type. Objects of type `cls` will be treated as container node types in
102102 // PyTrees.
103103 static void Register (const py::object &cls,
104104 const py::function &flatten_func,
105105 const py::function &unflatten_func,
106106 const py::object &path_entry_type,
107107 const std::string ®istry_namespace = " " );
108108
109+ // Unregister a previously registered custom type.
109110 static void Unregister (const py::object &cls, const std::string ®istry_namespace = " " );
110111
111- // Find the custom type registration for `type`. Returns nullptr if none exists.
112+ // Find the custom type registration for `type`. Return nullptr if none exists.
112113 template <bool NoneIsLeaf>
113114 [[nodiscard]] static RegistrationPtr Lookup (const py::object &cls,
114115 const std::string ®istry_namespace);
@@ -136,13 +137,18 @@ class PyTreeTypeRegistry {
136137 [[nodiscard]] static RegistrationPtr UnregisterImpl (const py::object &cls,
137138 const std::string ®istry_namespace);
138139
139- // Clear the registry on cleanup.
140+ // Clear the registry on cleanup for the current interpreter .
140141 static void Clear ();
141142
142- std::unordered_map<py::handle, RegistrationPtr> m_registrations{};
143- std::unordered_map<std::pair<std::string, py::handle>, RegistrationPtr> m_named_registrations{};
143+ using RegistrationsMap = std::unordered_map<py::handle, RegistrationPtr>;
144+ using NamedRegistrationsMap =
145+ std::unordered_map<std::pair<std::string, py::handle>, RegistrationPtr>;
146+ using BuiltinsTypesSet = std::unordered_set<py::handle>;
144147
145- static inline std::unordered_set<py::handle> sm_builtins_types{};
148+ RegistrationsMap m_registrations{};
149+ NamedRegistrationsMap m_named_registrations{};
150+
151+ static inline BuiltinsTypesSet sm_builtins_types{};
146152 static inline read_write_mutex sm_mutex{};
147153};
148154
0 commit comments