Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 11 additions & 20 deletions include/llvm-dialects/Dialect/OpMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -595,34 +595,26 @@ template <typename ValueT, bool isConst> class OpMapIteratorBase final {
if (std::get<BaseIteratorT>(m_iterator) == map->m_intrinsics.end())
invalidate();
}
} else {
createFromDialectOp(desc.getMnemonic());
} else if (!createFromDialectOp(desc.getMnemonic())) {
invalidate();
}
}

OpMapIteratorBase(OpMapT *map, const llvm::Function &func) : m_map{map} {
createFromFunc(func);
if (!createFromFunc(func))
invalidate();
}

// Do a lookup for a given instruction. Mark the iterator as invalid
// if the instruction is a call-like core instruction.
// Do a lookup for a given instruction.
OpMapIteratorBase(OpMapT *map, const llvm::Instruction &inst) : m_map{map} {
if (auto *CI = llvm::dyn_cast<llvm::CallInst>(&inst)) {
const llvm::Function *callee = CI->getCalledFunction();
if (callee) {
createFromFunc(*callee);
if (callee && createFromFunc(*callee))
return;
}
}

const unsigned op = inst.getOpcode();

// Construct an invalid iterator.
if (op == llvm::Instruction::Call || op == llvm::Instruction::CallBr) {
invalidate();
return;
}

BaseIteratorT it = m_map->m_coreOpcodes.find(op);
if (it != m_map->m_coreOpcodes.end()) {
m_desc = OpDescription::fromCoreOp(op);
Expand Down Expand Up @@ -699,20 +691,20 @@ template <typename ValueT, bool isConst> class OpMapIteratorBase final {
private:
void invalidate() { m_isInvalid = true; }

void createFromFunc(const llvm::Function &func) {
bool createFromFunc(const llvm::Function &func) {
if (func.isIntrinsic()) {
m_iterator = m_map->m_intrinsics.find(func.getIntrinsicID());

if (std::get<BaseIteratorT>(m_iterator) != m_map->m_intrinsics.end()) {
m_desc = OpDescription::fromIntrinsic(func.getIntrinsicID());
return;
return true;
}
}

createFromDialectOp(func.getName());
return createFromDialectOp(func.getName());
}

void createFromDialectOp(llvm::StringRef funcName) {
bool createFromDialectOp(llvm::StringRef funcName) {
size_t idx = 0;
bool found = false;
for (auto &dialectOpKV : m_map->m_dialectOps) {
Expand All @@ -729,8 +721,7 @@ template <typename ValueT, bool isConst> class OpMapIteratorBase final {
++idx;
}

if (!found)
invalidate();
return found;
}

// Re-construct base OpDescription from the stored iterator.
Expand Down
69 changes: 66 additions & 3 deletions test/unit/interface/OpMapIRTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,13 @@ TEST_F(OpMapIRTestFixture, IntrinsicOpMatchesInstructionTest) {
EXPECT_EQ(map[AssumeDesc], "assume");

const auto &SideEffect = *B.CreateCall(
Intrinsic::getDeclaration(Mod.get(), Intrinsic::sideeffect));
Intrinsic::getOrInsertDeclaration(Mod.get(), Intrinsic::sideeffect));

const std::array<Value *, 1> AssumeArgs = {
ConstantInt::getBool(Type::getInt1Ty(Context), true)};
const auto &Assume = *B.CreateCall(
Intrinsic::getDeclaration(Mod.get(), Intrinsic::assume), AssumeArgs);
Intrinsic::getOrInsertDeclaration(Mod.get(), Intrinsic::assume),
AssumeArgs);

EXPECT_FALSE(map.lookup(SideEffect) == map.lookup(Assume));
EXPECT_EQ(map.lookup(SideEffect), "sideeffect");
Expand Down Expand Up @@ -171,7 +172,7 @@ TEST_F(OpMapIRTestFixture, MixedOpMatchesInstructionTest) {
EXPECT_EQ(map[SideEffectDesc], "sideeffect");

const auto &SideEffect = *B.CreateCall(
Intrinsic::getDeclaration(Mod.get(), Intrinsic::sideeffect));
Intrinsic::getOrInsertDeclaration(Mod.get(), Intrinsic::sideeffect));

EXPECT_EQ(map.lookup(SideEffect), "sideeffect");

Expand Down Expand Up @@ -252,3 +253,65 @@ TEST_F(OpMapIRTestFixture, DialectOpOverloadTests) {
EXPECT_EQ(map.lookup(Op1), "DialectOp4");
EXPECT_EQ(map.lookup(Op2), "DialectOp4");
}

TEST_F(OpMapIRTestFixture, CallCoreOpMatchesInstructionTest) {
OpMap<StringRef> map;
llvm_dialects::Builder B{Context};

// Define types
PointerType *PtrTy = B.getPtrTy();
IntegerType *I32Ty = Type::getInt32Ty(Context);

// Declare: %ptr @ProcOpaqueHandle(i32, %ptr)
FunctionType *ProcOpaqueHandleFuncTy =
FunctionType::get(PtrTy, {I32Ty, PtrTy}, false);
FunctionCallee ProcOpaqueHandleFunc =
Mod->getOrInsertFunction("ProcOpaqueHandle", ProcOpaqueHandleFuncTy);

B.SetInsertPoint(getEntryBlock());

// Declare %OpaqueTy = type opaque
StructType *OpaqueTy = StructType::create(Context, "OpaqueTy");

// Create a dummy global variable of type %OpaqueTy*
GlobalVariable *GV = new GlobalVariable(
*Mod, OpaqueTy, false, GlobalValue::PrivateLinkage, nullptr, "handle");
GV->setInitializer(ConstantAggregateZero::get(OpaqueTy));
Value *Op2 = GV;

// Create a constant value (e.g., 123)
Value *Op1 = B.getInt32(123);

// Build a call instruction
Value *Args[] = {Op1, Op2};
const CallInst &Call = *B.CreateCall(ProcOpaqueHandleFunc, Args);

// Create basic blocks for the function
auto *FC = getEntryBlock()->getParent();
BasicBlock *Label1BB = BasicBlock::Create(Context, "label1", FC);
BasicBlock *Label2BB = BasicBlock::Create(Context, "label2", FC);
BasicBlock *ContinueBB = BasicBlock::Create(Context, "continue", FC);

// Simulate a function that can branch to multiple labels
// For demonstration purposes, we'll create a placeholder function that represents this behavior
FunctionType *BranchFuncTy = FunctionType::get(Type::getVoidTy(Context), false);
FunctionCallee BranchFunc = Mod->getOrInsertFunction("Branch", BranchFuncTy);

// Create the CallBr instruction
const CallBrInst &CallBr = *B.CreateCallBr(BranchFunc, ContinueBB, {Label1BB, Label2BB});

// Load and test OpMap with Call and CallBr

// Add Instruction::Call to OpMap
const OpDescription CallDesc = OpDescription::fromCoreOp(Instruction::Call);
map[CallDesc] = "Call";

// Add Instruction::CallBr to OpMap
const OpDescription CallBrDesc = OpDescription::fromCoreOp(Instruction::CallBr);
map[CallBrDesc] = "CallBr";

// Look up the Call and CallBr in the map and verify it finds the entries for
// Instruction::Call and Instruction::CallBr
EXPECT_EQ(map.lookup(Call), "Call");
EXPECT_EQ(map.lookup(CallBr), "CallBr");
}