@@ -38,8 +38,11 @@ PyGlobals::PyGlobals() {
3838PyGlobals::~PyGlobals () { instance = nullptr ; }
3939
4040bool  PyGlobals::loadDialectModule (llvm::StringRef dialectNamespace) {
41-   if  (loadedDialectModules.contains (dialectNamespace))
42-     return  true ;
41+   {
42+     nb::ft_lock_guard lock (mutex);
43+     if  (loadedDialectModules.contains (dialectNamespace))
44+       return  true ;
45+   }
4346  //  Since re-entrancy is possible, make a copy of the search prefixes.
4447  std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
4548  nb::object loaded = nb::none ();
@@ -62,12 +65,14 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
6265    return  false ;
6366  //  Note: Iterator cannot be shared from prior to loading, since re-entrancy
6467  //  may have occurred, which may do anything.
68+   nb::ft_lock_guard lock (mutex);
6569  loadedDialectModules.insert (dialectNamespace);
6670  return  true ;
6771}
6872
6973void  PyGlobals::registerAttributeBuilder (const  std::string &attributeKind,
7074                                         nb::callable pyFunc, bool  replace) {
75+   nb::ft_lock_guard lock (mutex);
7176  nb::object &found = attributeBuilderMap[attributeKind];
7277  if  (found && !replace) {
7378    throw  std::runtime_error ((llvm::Twine (" Attribute builder for '"  ) +
@@ -81,6 +86,7 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
8186
8287void  PyGlobals::registerTypeCaster (MlirTypeID mlirTypeID,
8388                                   nb::callable typeCaster, bool  replace) {
89+   nb::ft_lock_guard lock (mutex);
8490  nb::object &found = typeCasterMap[mlirTypeID];
8591  if  (found && !replace)
8692    throw  std::runtime_error (" Type caster is already registered with caster: "   +
@@ -90,6 +96,7 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
9096
9197void  PyGlobals::registerValueCaster (MlirTypeID mlirTypeID,
9298                                    nb::callable valueCaster, bool  replace) {
99+   nb::ft_lock_guard lock (mutex);
93100  nb::object &found = valueCasterMap[mlirTypeID];
94101  if  (found && !replace)
95102    throw  std::runtime_error (" Value caster is already registered: "   +
@@ -99,6 +106,7 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
99106
100107void  PyGlobals::registerDialectImpl (const  std::string &dialectNamespace,
101108                                    nb::object pyClass) {
109+   nb::ft_lock_guard lock (mutex);
102110  nb::object &found = dialectClassMap[dialectNamespace];
103111  if  (found) {
104112    throw  std::runtime_error ((llvm::Twine (" Dialect namespace '"  ) +
@@ -110,6 +118,7 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
110118
111119void  PyGlobals::registerOperationImpl (const  std::string &operationName,
112120                                      nb::object pyClass, bool  replace) {
121+   nb::ft_lock_guard lock (mutex);
113122  nb::object &found = operationClassMap[operationName];
114123  if  (found && !replace) {
115124    throw  std::runtime_error ((llvm::Twine (" Operation '"  ) + operationName +
@@ -121,6 +130,7 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
121130
122131std::optional<nb::callable>
123132PyGlobals::lookupAttributeBuilder (const  std::string &attributeKind) {
133+   nb::ft_lock_guard lock (mutex);
124134  const  auto  foundIt = attributeBuilderMap.find (attributeKind);
125135  if  (foundIt != attributeBuilderMap.end ()) {
126136    assert (foundIt->second  && " attribute builder is defined"  );
@@ -133,6 +143,7 @@ std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
133143                                                        MlirDialect dialect) {
134144  //  Try to load dialect module.
135145  (void )loadDialectModule (unwrap (mlirDialectGetNamespace (dialect)));
146+   nb::ft_lock_guard lock (mutex);
136147  const  auto  foundIt = typeCasterMap.find (mlirTypeID);
137148  if  (foundIt != typeCasterMap.end ()) {
138149    assert (foundIt->second  && " type caster is defined"  );
@@ -145,6 +156,7 @@ std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
145156                                                         MlirDialect dialect) {
146157  //  Try to load dialect module.
147158  (void )loadDialectModule (unwrap (mlirDialectGetNamespace (dialect)));
159+   nb::ft_lock_guard lock (mutex);
148160  const  auto  foundIt = valueCasterMap.find (mlirTypeID);
149161  if  (foundIt != valueCasterMap.end ()) {
150162    assert (foundIt->second  && " value caster is defined"  );
@@ -158,6 +170,7 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
158170  //  Make sure dialect module is loaded.
159171  if  (!loadDialectModule (dialectNamespace))
160172    return  std::nullopt ;
173+   nb::ft_lock_guard lock (mutex);
161174  const  auto  foundIt = dialectClassMap.find (dialectNamespace);
162175  if  (foundIt != dialectClassMap.end ()) {
163176    assert (foundIt->second  && " dialect class is defined"  );
@@ -175,6 +188,7 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
175188  if  (!loadDialectModule (dialectNamespace))
176189    return  std::nullopt ;
177190
191+   nb::ft_lock_guard lock (mutex);
178192  auto  foundIt = operationClassMap.find (operationName);
179193  if  (foundIt != operationClassMap.end ()) {
180194    assert (foundIt->second  && " OpView is defined"  );
0 commit comments