Skip to content

Commit 62337aa

Browse files
authored
Use funcop for indirect call (#274)
* Use funcop for indirect call * Add test * Add comment
1 parent 7fbcef9 commit 62337aa

File tree

5 files changed

+94
-6
lines changed

5 files changed

+94
-6
lines changed

include/polygeist/PolygeistOps.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
include "Dialect.td"
1313
include "mlir/Interfaces/SideEffectInterfaces.td"
1414
include "mlir/Interfaces/ViewLikeInterface.td"
15+
include "mlir/IR/SymbolInterfaces.td"
1516

1617
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
1718
include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td"
@@ -111,6 +112,14 @@ def Pointer2MemrefOp : Polygeist_Op<"pointer2memref", [
111112
}];
112113
}
113114

115+
def GetFuncOp : Polygeist_Op<"get_func",
116+
[NoSideEffect, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
117+
let summary = "get the pointer pointing to a function";
118+
let arguments = (ins FlatSymbolRefAttr:$name);
119+
let results = (outs LLVM_AnyPointer : $result);
120+
let assemblyFormat = "$name `:` type($result) attr-dict";
121+
}
122+
114123
def TrivialUseOp : Polygeist_Op<"trivialuse"> {
115124
let summary = "memref subview operation";
116125

lib/polygeist/Ops.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5005,3 +5005,19 @@ void TypeAlignOp::getCanonicalizationPatterns(RewritePatternSet &results,
50055005
// RankReduction<memref::AllocaOp, scf::ParallelOp>,
50065006
AggressiveAllocaScopeInliner, InductiveVarRemoval>(context);
50075007
}
5008+
5009+
//===----------------------------------------------------------------------===//
5010+
// GetFuncOp
5011+
//===----------------------------------------------------------------------===//
5012+
5013+
LogicalResult GetFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
5014+
// TODO: Verify that the result type is same as the type of the referenced
5015+
// func.func op.
5016+
auto global =
5017+
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getNameAttr());
5018+
if (!global)
5019+
return emitOpError("'")
5020+
<< getName() << "' does not reference a valid global funcOp";
5021+
5022+
return success();
5023+
}

lib/polygeist/Passes/ConvertPolygeistToLLVM.cpp

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,29 @@ struct GlobalOpTypeConversion : public OpConversionPattern<LLVM::GlobalOp> {
589589
}
590590
};
591591

592+
struct GetFuncOpConversion : public OpConversionPattern<polygeist::GetFuncOp> {
593+
explicit GetFuncOpConversion(LLVMTypeConverter &converter)
594+
: OpConversionPattern<polygeist::GetFuncOp>(converter,
595+
&converter.getContext()) {}
596+
597+
LogicalResult
598+
matchAndRewrite(polygeist::GetFuncOp op,
599+
polygeist::GetFuncOp::Adaptor adaptor,
600+
ConversionPatternRewriter &rewriter) const override {
601+
TypeConverter *converter = getTypeConverter();
602+
Type retType = op.getType();
603+
604+
Type convertedType = converter->convertType(retType);
605+
if (!convertedType)
606+
return failure();
607+
608+
rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, convertedType,
609+
op.getName());
610+
611+
return success();
612+
}
613+
};
614+
592615
struct ReturnOpTypeConversion : public ConvertOpToLLVMPattern<LLVM::ReturnOp> {
593616
using ConvertOpToLLVMPattern<LLVM::ReturnOp>::ConvertOpToLLVMPattern;
594617

@@ -641,9 +664,8 @@ struct ConvertPolygeistToLLVMPass
641664

642665
converter.addConversion([&](async::TokenType type) { return type; });
643666

644-
patterns
645-
.add<LLVMOpLowering, GlobalOpTypeConversion, ReturnOpTypeConversion>(
646-
converter);
667+
patterns.add<LLVMOpLowering, GlobalOpTypeConversion,
668+
ReturnOpTypeConversion, GetFuncOpConversion>(converter);
647669
patterns.add<URLLVMOpLowering>(converter);
648670

649671
// Legality callback for operations that checks whether their operand and

tools/cgeist/Lib/clang-mlir.cc

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3098,10 +3098,19 @@ ValueCategory MLIRScanner::VisitDeclRefExpr(DeclRefExpr *E) {
30983098
auto loc = getMLIRLocation(E->getLocation());
30993099
auto name = E->getDecl()->getName().str();
31003100

3101-
if (auto tocall = dyn_cast<FunctionDecl>(E->getDecl()))
3102-
return ValueCategory(builder.create<LLVM::AddressOfOp>(
3103-
loc, Glob.GetOrCreateLLVMFunction(tocall)),
3101+
if (auto tocall = dyn_cast<FunctionDecl>(E->getDecl())) {
3102+
auto f = Glob.GetOrCreateMLIRFunction(tocall);
3103+
auto FT = f.getFunctionType();
3104+
mlir::Type RT = LLVM::LLVMVoidType::get(f.getContext());
3105+
if (FT.getNumResults() != 0)
3106+
RT = FT.getResult(0);
3107+
LLVM::LLVMFunctionType LFT = LLVM::LLVMFunctionType::get(
3108+
RT, FT.getInputs(), /*unsupported presentlyFT.isVariadic()*/ false);
3109+
3110+
return ValueCategory(builder.create<polygeist::GetFuncOp>(
3111+
loc, LLVM::LLVMPointerType::get(LFT), f.getName()),
31043112
/*isReference*/ true);
3113+
}
31053114

31063115
if (auto VD = dyn_cast<VarDecl>(E->getDecl())) {
31073116
if (Captures.find(VD) != Captures.end()) {
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// RUN: cgeist %s --function=main -S | FileCheck %s
2+
// RUN: cgeist %s --function=main -S --emit-llvm | FileCheck %s --check-prefix=LLCHECK
3+
4+
int square(int x) {
5+
return x*x;
6+
}
7+
int meta(int (*f)(int), int x) {
8+
return f(x);
9+
}
10+
11+
int printf(const char*, ...);
12+
int main() {
13+
printf("sq(%d)=%d\n", 3, meta(square, 3));
14+
return 0;
15+
}
16+
17+
// CHECK: func.func @main() -> i32 attributes
18+
// CHECK-DAG: %c3_i32 = arith.constant 3 : i32
19+
// CHECK-DAG: %c0_i32 = arith.constant 0 : i32
20+
// CHECK-NEXT: %0 = llvm.mlir.addressof @str0 : !llvm.ptr<array<11 x i8>>
21+
// CHECK-NEXT: %1 = llvm.getelementptr %0[0, 0] : (!llvm.ptr<array<11 x i8>>) -> !llvm.ptr<i8>
22+
// CHECK-NEXT: %2 = polygeist.get_func @square : !llvm.ptr<func<i32 (i32)>>
23+
// CHECK-NEXT: %3 = call @meta(%2, %c3_i32) : (!llvm.ptr<func<i32 (i32)>>, i32) -> i32
24+
// CHECK-NEXT: %4 = llvm.call @printf(%1, %c3_i32, %3) : (!llvm.ptr<i8>, i32, i32) -> i32
25+
// CHECK-NEXT: return %c0_i32 : i32
26+
// CHECK-NEXT: }
27+
28+
// LLCHECK: define i32 @main()
29+
// LLCHECK: %1 = call i32 @meta(i32 (i32)* @square, i32 3)
30+
// LLCHECK: %2 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([11 x i8], [11 x i8]* @str0, i32 0, i32 0), i32 3, i32 %1)
31+
// LLCHECK: ret i32 0
32+
// LLCHECK: }

0 commit comments

Comments
 (0)