Skip to content

Commit 911b6a5

Browse files
authored
Fix builtin addressof handling and reference type initexpr (#304)
* Fix builtin addressof handling and reference type initexpr * clang-format * Fix warning print
1 parent f424286 commit 911b6a5

File tree

3 files changed

+71
-30
lines changed

3 files changed

+71
-30
lines changed

tools/cgeist/Lib/CGCall.cc

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "clang-mlir.h"
1111
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1212
#include "utils.h"
13+
#include "clang/Basic/Builtins.h"
1314

1415
#define DEBUG_TYPE "CGCall"
1516

@@ -367,6 +368,8 @@ ValueCategory MLIRScanner::CallHelper(
367368

368369
std::pair<ValueCategory, bool>
369370
MLIRScanner::EmitClangBuiltinCallExpr(clang::CallExpr *expr) {
371+
auto loc = getMLIRLocation(expr->getExprLoc());
372+
370373
switch (expr->getBuiltinCallee()) {
371374
case clang::Builtin::BImove:
372375
case clang::Builtin::BImove_if_noexcept:
@@ -375,6 +378,34 @@ MLIRScanner::EmitClangBuiltinCallExpr(clang::CallExpr *expr) {
375378
auto V = Visit(expr->getArg(0));
376379
return make_pair(V, true);
377380
}
381+
case clang::Builtin::BIaddressof:
382+
case clang::Builtin::BI__addressof:
383+
case clang::Builtin::BI__builtin_addressof: {
384+
auto V = Visit(expr->getArg(0));
385+
assert(V.isReference);
386+
mlir::Value val = V.val;
387+
auto T = getMLIRType(expr->getType());
388+
if (T == val.getType())
389+
return make_pair(ValueCategory(val, /*isRef*/ false), true);
390+
if (T.isa<LLVM::LLVMPointerType>()) {
391+
if (val.getType().isa<MemRefType>())
392+
val = builder.create<polygeist::Memref2PointerOp>(loc, T, val);
393+
else if (T != val.getType())
394+
val = builder.create<LLVM::BitcastOp>(loc, T, val);
395+
return make_pair(ValueCategory(val, /*isRef*/ false), true);
396+
} else {
397+
assert(T.isa<MemRefType>());
398+
if (val.getType().isa<MemRefType>())
399+
val = builder.create<polygeist::Memref2PointerOp>(
400+
loc, LLVM::LLVMPointerType::get(builder.getI8Type()), val);
401+
if (val.getType().isa<LLVM::LLVMPointerType>())
402+
val = builder.create<polygeist::Pointer2MemrefOp>(loc, T, val);
403+
return make_pair(ValueCategory(val, /*isRef*/ false), true);
404+
}
405+
expr->dump();
406+
llvm::errs() << " val: " << val << " T: " << T << "\n";
407+
assert(0 && "unhandled builtin addressof");
408+
}
378409
default:
379410
break;
380411
}
@@ -591,36 +622,6 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
591622
return Visit(expr->getArg(0));
592623
}
593624
}
594-
if (auto *ic = dyn_cast<ImplicitCastExpr>(expr->getCallee()))
595-
if (auto *sr = dyn_cast<DeclRefExpr>(ic->getSubExpr())) {
596-
if (sr->getDecl()->getIdentifier() &&
597-
sr->getDecl()->getName() == "__builtin_addressof") {
598-
auto V = Visit(expr->getArg(0));
599-
assert(V.isReference);
600-
mlir::Value val = V.val;
601-
auto T = getMLIRType(expr->getType());
602-
if (T == val.getType())
603-
return ValueCategory(val, /*isRef*/ false);
604-
if (T.isa<LLVM::LLVMPointerType>()) {
605-
if (val.getType().isa<MemRefType>())
606-
val = builder.create<polygeist::Memref2PointerOp>(loc, T, val);
607-
else if (T != val.getType())
608-
val = builder.create<LLVM::BitcastOp>(loc, T, val);
609-
return ValueCategory(val, /*isRef*/ false);
610-
} else {
611-
assert(T.isa<MemRefType>());
612-
if (val.getType().isa<MemRefType>())
613-
val = builder.create<polygeist::Memref2PointerOp>(
614-
loc, LLVM::LLVMPointerType::get(builder.getI8Type()), val);
615-
if (val.getType().isa<LLVM::LLVMPointerType>())
616-
val = builder.create<polygeist::Pointer2MemrefOp>(loc, T, val);
617-
return ValueCategory(val, /*isRef*/ false);
618-
}
619-
expr->dump();
620-
llvm::errs() << " val: " << val << " T: " << T << "\n";
621-
assert(0 && "unhandled builtin addressof");
622-
}
623-
}
624625
if (auto *ic = dyn_cast<ImplicitCastExpr>(expr->getCallee()))
625626
if (auto *sr = dyn_cast<DeclRefExpr>(ic->getSubExpr())) {
626627
if (sr->getDecl()->getIdentifier() &&
@@ -1406,6 +1407,11 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
14061407
}
14071408
#endif
14081409

1410+
if (auto BI = expr->getBuiltinCallee())
1411+
if (!Glob.CGM.getContext().BuiltinInfo.isPredefinedLibFunction(BI))
1412+
llvm::errs() << "warning: we failed to emit call to builtin function "
1413+
<< Glob.CGM.getContext().BuiltinInfo.getName(BI) << "\n";
1414+
14091415
const auto *callee = EmitCallee(expr->getCallee());
14101416

14111417
std::set<std::string> funcs = {

tools/cgeist/Lib/clang-mlir.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,10 @@ void MLIRScanner::init(mlir::func::FuncOp function, const FunctionDecl *fd) {
370370
expr->getInit()->dump();
371371
assert(initexpr.val);
372372
}
373+
if (field->getType()->isReferenceType()) {
374+
assert(initexpr.isReference);
375+
initexpr.isReference = false;
376+
}
373377
bool isArray = false;
374378
Glob.getMLIRType(expr->getInit()->getType(), &isArray);
375379

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: cgeist %s %stdinclude --function=* -S | FileCheck %s
2+
3+
#include <memory>
4+
5+
struct Ptr {
6+
};
7+
8+
Ptr *foo()
9+
{
10+
Ptr p;
11+
return std::addressof(p); // calls Ptr<int>* overload, (= this)
12+
}
13+
14+
Ptr *bar()
15+
{
16+
Ptr p;
17+
return __builtin_addressof(p); // calls Ptr<int>* overload, (= this)
18+
}
19+
20+
// CHECK-LABEL: func.func @_Z3foov() -> memref<?x!llvm.struct<(i8)>> attributes {llvm.linkage = #llvm.linkage<external>} {
21+
// CHECK: %[[VAL_0:.*]] = memref.alloca() : memref<1x!llvm.struct<(i8)>>
22+
// CHECK: %[[VAL_1:.*]] = memref.cast %[[VAL_0]] : memref<1x!llvm.struct<(i8)>> to memref<?x!llvm.struct<(i8)>>
23+
// CHECK: return %[[VAL_1]] : memref<?x!llvm.struct<(i8)>>
24+
// CHECK: }
25+
26+
// CHECK-LABEL: func.func @_Z3barv() -> memref<?x!llvm.struct<(i8)>> attributes {llvm.linkage = #llvm.linkage<external>} {
27+
// CHECK: %[[VAL_0:.*]] = memref.alloca() : memref<1x!llvm.struct<(i8)>>
28+
// CHECK: %[[VAL_1:.*]] = memref.cast %[[VAL_0]] : memref<1x!llvm.struct<(i8)>> to memref<?x!llvm.struct<(i8)>>
29+
// CHECK: return %[[VAL_1]] : memref<?x!llvm.struct<(i8)>>
30+
// CHECK: }
31+

0 commit comments

Comments
 (0)