diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp index 430502d85dfb4..cdb0f559d78b4 100644 --- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp @@ -128,13 +128,18 @@ static std::optional findOneNVVMAnnotation(const GlobalValue *gv, auto &AC = getAnnotationCache(); std::lock_guard Guard(AC.Lock); const Module *m = gv->getParent(); - if (AC.Cache.find(m) == AC.Cache.end()) + auto ACIt = AC.Cache.find(m); + if (ACIt == AC.Cache.end()) cacheAnnotationFromMD(m, gv); - else if (AC.Cache[m].find(gv) == AC.Cache[m].end()) + else if (ACIt->second.find(gv) == ACIt->second.end()) cacheAnnotationFromMD(m, gv); - if (AC.Cache[m][gv].find(prop) == AC.Cache[m][gv].end()) + // Look up AC.Cache[m][gv] again because cacheAnnotationFromMD may have + // inserted the entry. + auto &KVP = AC.Cache[m][gv]; + auto It = KVP.find(prop); + if (It == KVP.end()) return std::nullopt; - return AC.Cache[m][gv][prop][0]; + return It->second[0]; } static bool findAllNVVMAnnotation(const GlobalValue *gv, @@ -143,13 +148,18 @@ static bool findAllNVVMAnnotation(const GlobalValue *gv, auto &AC = getAnnotationCache(); std::lock_guard Guard(AC.Lock); const Module *m = gv->getParent(); - if (AC.Cache.find(m) == AC.Cache.end()) + auto ACIt = AC.Cache.find(m); + if (ACIt == AC.Cache.end()) cacheAnnotationFromMD(m, gv); - else if (AC.Cache[m].find(gv) == AC.Cache[m].end()) + else if (ACIt->second.find(gv) == ACIt->second.end()) cacheAnnotationFromMD(m, gv); - if (AC.Cache[m][gv].find(prop) == AC.Cache[m][gv].end()) + // Look up AC.Cache[m][gv] again because cacheAnnotationFromMD may have + // inserted the entry. + auto &KVP = AC.Cache[m][gv]; + auto It = KVP.find(prop); + if (It == KVP.end()) return false; - retval = AC.Cache[m][gv][prop]; + retval = It->second; return true; }