diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index 30c190e50a4f7..97ae14aa0d6af 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -142,21 +142,20 @@ class ModuleTranslation { auto result = unresolvedBlockAddressMapping.try_emplace(op, cst); (void)result; assert(result.second && - "attempting to map a blockaddress that is already mapped"); + "attempting to map a blockaddress operation that is already mapped"); } - /// Maps a blockaddress operation to its corresponding placeholder LLVM - /// value. - void mapBlockTag(BlockAddressAttr attr, BlockTagOp blockTag) { - // Attempts to map already mapped block labels which is fine if the given - // labels are verified to be unique. - blockTagMapping[attr] = blockTag; + /// Maps a BlockAddressAttr to its corresponding LLVM basic block. + void mapBlockAddress(BlockAddressAttr attr, llvm::BasicBlock *block) { + auto result = blockAddressToLLVMMapping.try_emplace(attr, block); + (void)result; + assert(result.second && + "attempting to map a blockaddress attribute that is already mapped"); } - /// Finds an MLIR block that corresponds to the given MLIR call - /// operation. - BlockTagOp lookupBlockTag(BlockAddressAttr attr) const { - return blockTagMapping.lookup(attr); + /// Finds the LLVM basic block that corresponds to the given BlockAddressAttr. + llvm::BasicBlock *lookupBlockAddress(BlockAddressAttr attr) const { + return blockAddressToLLVMMapping.lookup(attr); } /// Removes the mapping for blocks contained in the region and values defined @@ -463,10 +462,9 @@ class ModuleTranslation { /// mapping is used to replace the placeholders with the LLVM block addresses. DenseMap unresolvedBlockAddressMapping; - /// Mapping from a BlockAddressAttr attribute to a matching BlockTagOp. This - /// is used to cache BlockTagOp locations instead of walking a LLVMFuncOp in - /// search for those. - DenseMap blockTagMapping; + /// Mapping from a BlockAddressAttr attribute to it's matching LLVM basic + /// block. + DenseMap blockAddressToLLVMMapping; /// Stack of user-specified state elements, useful when translating operations /// with regions. diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index 4ea313019f34d..9470b54c9f3aa 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -690,19 +690,13 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, // Emit blockaddress. We first need to find the LLVM block referenced by this // operation and then create a LLVM block address for it. if (auto blockAddressOp = dyn_cast(opInst)) { - // getBlockTagOp() walks a function to search for block labels. Check - // whether it's in cache first. BlockAddressAttr blockAddressAttr = blockAddressOp.getBlockAddr(); - BlockTagOp blockTagOp = moduleTranslation.lookupBlockTag(blockAddressAttr); - if (!blockTagOp) { - blockTagOp = blockAddressOp.getBlockTagOp(); - moduleTranslation.mapBlockTag(blockAddressAttr, blockTagOp); - } + llvm::BasicBlock *llvmBlock = + moduleTranslation.lookupBlockAddress(blockAddressAttr); llvm::Value *llvmValue = nullptr; StringRef fnName = blockAddressAttr.getFunction().getValue(); - if (llvm::BasicBlock *llvmBlock = - moduleTranslation.lookupBlock(blockTagOp->getBlock())) { + if (llvmBlock) { llvm::Function *llvmFn = moduleTranslation.lookupFunction(fnName); llvmValue = llvm::BlockAddress::get(llvmFn, llvmBlock); } else { @@ -736,7 +730,8 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, FlatSymbolRefAttr::get(&moduleTranslation.getContext(), funcOp.getName()), blockTagOp.getTag()); - moduleTranslation.mapBlockTag(blockAddressAttr, blockTagOp); + moduleTranslation.mapBlockAddress(blockAddressAttr, + builder.GetInsertBlock()); return success(); } diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 1168b9f339904..95b8ee0331c55 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1843,17 +1843,13 @@ LogicalResult ModuleTranslation::convertComdats() { LogicalResult ModuleTranslation::convertUnresolvedBlockAddress() { for (auto &[blockAddressOp, llvmCst] : unresolvedBlockAddressMapping) { BlockAddressAttr blockAddressAttr = blockAddressOp.getBlockAddr(); - BlockTagOp blockTagOp = lookupBlockTag(blockAddressAttr); - assert(blockTagOp && "expected all block tags to be already seen"); - - llvm::BasicBlock *llvmBlock = lookupBlock(blockTagOp->getBlock()); + llvm::BasicBlock *llvmBlock = lookupBlockAddress(blockAddressAttr); assert(llvmBlock && "expected LLVM blocks to be already translated"); // Update mapping with new block address constant. auto *llvmBlockAddr = llvm::BlockAddress::get( lookupFunction(blockAddressAttr.getFunction().getValue()), llvmBlock); llvmCst->replaceAllUsesWith(llvmBlockAddr); - mapValue(blockAddressOp.getResult(), llvmBlockAddr); assert(llvmCst->use_empty() && "expected all uses to be replaced"); cast(llvmCst)->eraseFromParent(); } diff --git a/mlir/test/Target/LLVMIR/blockaddress.mlir b/mlir/test/Target/LLVMIR/blockaddress.mlir index fb3d853531122..4473f91c4bdb5 100644 --- a/mlir/test/Target/LLVMIR/blockaddress.mlir +++ b/mlir/test/Target/LLVMIR/blockaddress.mlir @@ -34,3 +34,32 @@ llvm.func @blockaddr0() -> !llvm.ptr { // CHECK: [[RET]]: // CHECK: ret ptr blockaddress(@blockaddr0, %1) // CHECK: } + +// ----- + +llvm.mlir.global private @h() {addr_space = 0 : i32, dso_local} : !llvm.ptr { + %0 = llvm.blockaddress > : !llvm.ptr + llvm.return %0 : !llvm.ptr +} + +// CHECK: @h = private global ptr blockaddress(@h3, %[[BB_ADDR:.*]]) + +// CHECK: define void @h3() { +// CHECK: br label %[[BB_ADDR]] + +// CHECK: [[BB_ADDR]]: +// CHECK: ret void +// CHECK: } + +// CHECK: define void @h0() + +llvm.func @h3() { + llvm.br ^bb1 +^bb1: + llvm.blocktag + llvm.return +} + +llvm.func @h0() { + llvm.return +}