@@ -38,8 +38,11 @@ PyGlobals::PyGlobals() {
3838PyGlobals::~PyGlobals () { instance = nullptr ; }
3939
4040bool PyGlobals::loadDialectModule (llvm::StringRef dialectNamespace) {
41- if (loadedDialectModules.contains (dialectNamespace))
42- return true ;
41+ {
42+ nb::ft_lock_guard lock (mutex);
43+ if (loadedDialectModules.contains (dialectNamespace))
44+ return true ;
45+ }
4346 // Since re-entrancy is possible, make a copy of the search prefixes.
4447 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
4548 nb::object loaded = nb::none ();
@@ -62,12 +65,14 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
6265 return false ;
6366 // Note: Iterator cannot be shared from prior to loading, since re-entrancy
6467 // may have occurred, which may do anything.
68+ nb::ft_lock_guard lock (mutex);
6569 loadedDialectModules.insert (dialectNamespace);
6670 return true ;
6771}
6872
6973void PyGlobals::registerAttributeBuilder (const std::string &attributeKind,
7074 nb::callable pyFunc, bool replace) {
75+ nb::ft_lock_guard lock (mutex);
7176 nb::object &found = attributeBuilderMap[attributeKind];
7277 if (found && !replace) {
7378 throw std::runtime_error ((llvm::Twine (" Attribute builder for '" ) +
@@ -81,6 +86,7 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
8186
8287void PyGlobals::registerTypeCaster (MlirTypeID mlirTypeID,
8388 nb::callable typeCaster, bool replace) {
89+ nb::ft_lock_guard lock (mutex);
8490 nb::object &found = typeCasterMap[mlirTypeID];
8591 if (found && !replace)
8692 throw std::runtime_error (" Type caster is already registered with caster: " +
@@ -90,6 +96,7 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
9096
9197void PyGlobals::registerValueCaster (MlirTypeID mlirTypeID,
9298 nb::callable valueCaster, bool replace) {
99+ nb::ft_lock_guard lock (mutex);
93100 nb::object &found = valueCasterMap[mlirTypeID];
94101 if (found && !replace)
95102 throw std::runtime_error (" Value caster is already registered: " +
@@ -99,6 +106,7 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
99106
100107void PyGlobals::registerDialectImpl (const std::string &dialectNamespace,
101108 nb::object pyClass) {
109+ nb::ft_lock_guard lock (mutex);
102110 nb::object &found = dialectClassMap[dialectNamespace];
103111 if (found) {
104112 throw std::runtime_error ((llvm::Twine (" Dialect namespace '" ) +
@@ -110,6 +118,7 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
110118
111119void PyGlobals::registerOperationImpl (const std::string &operationName,
112120 nb::object pyClass, bool replace) {
121+ nb::ft_lock_guard lock (mutex);
113122 nb::object &found = operationClassMap[operationName];
114123 if (found && !replace) {
115124 throw std::runtime_error ((llvm::Twine (" Operation '" ) + operationName +
@@ -121,6 +130,7 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
121130
122131std::optional<nb::callable>
123132PyGlobals::lookupAttributeBuilder (const std::string &attributeKind) {
133+ nb::ft_lock_guard lock (mutex);
124134 const auto foundIt = attributeBuilderMap.find (attributeKind);
125135 if (foundIt != attributeBuilderMap.end ()) {
126136 assert (foundIt->second && " attribute builder is defined" );
@@ -133,6 +143,7 @@ std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
133143 MlirDialect dialect) {
134144 // Try to load dialect module.
135145 (void )loadDialectModule (unwrap (mlirDialectGetNamespace (dialect)));
146+ nb::ft_lock_guard lock (mutex);
136147 const auto foundIt = typeCasterMap.find (mlirTypeID);
137148 if (foundIt != typeCasterMap.end ()) {
138149 assert (foundIt->second && " type caster is defined" );
@@ -145,6 +156,7 @@ std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
145156 MlirDialect dialect) {
146157 // Try to load dialect module.
147158 (void )loadDialectModule (unwrap (mlirDialectGetNamespace (dialect)));
159+ nb::ft_lock_guard lock (mutex);
148160 const auto foundIt = valueCasterMap.find (mlirTypeID);
149161 if (foundIt != valueCasterMap.end ()) {
150162 assert (foundIt->second && " value caster is defined" );
@@ -158,6 +170,7 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
158170 // Make sure dialect module is loaded.
159171 if (!loadDialectModule (dialectNamespace))
160172 return std::nullopt ;
173+ nb::ft_lock_guard lock (mutex);
161174 const auto foundIt = dialectClassMap.find (dialectNamespace);
162175 if (foundIt != dialectClassMap.end ()) {
163176 assert (foundIt->second && " dialect class is defined" );
@@ -175,6 +188,7 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
175188 if (!loadDialectModule (dialectNamespace))
176189 return std::nullopt ;
177190
191+ nb::ft_lock_guard lock (mutex);
178192 auto foundIt = operationClassMap.find (operationName);
179193 if (foundIt != operationClassMap.end ()) {
180194 assert (foundIt->second && " OpView is defined" );
0 commit comments