Skip to content

Commit 5f1b14f

Browse files
authored
Fix llvm struct abi field lowering (#198)
* Fix lookup where no base is created in llvm abi * Add test * Bump LLVM * Fix API change * Fix union * Add test
1 parent f51ff0d commit 5f1b14f

File tree

6 files changed

+199
-34
lines changed

6 files changed

+199
-34
lines changed

lib/polygeist/Passes/ConvertPolygeistToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ struct LLVMOpLowering : public ConversionPattern {
293293
for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
294294
state.addRegion();
295295

296-
Operation *rewritten = rewriter.createOperation(state);
296+
Operation *rewritten = rewriter.create(state);
297297
rewriter.replaceOp(op, rewritten->getResults());
298298

299299
for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)

llvm-project

Submodule llvm-project updated 2152 files

tools/mlir-clang/Lib/clang-mlir.cc

Lines changed: 79 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -256,12 +256,6 @@ void MLIRScanner::init(mlir::FuncOp function, const FunctionDecl *fd) {
256256
llvm::errs() << " warning, destructor not fully handled yet\n";
257257
}
258258

259-
Stmt *stmt = fd->getBody();
260-
assert(stmt);
261-
if (ShowAST) {
262-
stmt->dump();
263-
}
264-
265259
auto i1Ty = builder.getIntegerType(1);
266260
auto type = mlir::MemRefType::get({}, i1Ty, {}, 0);
267261
auto truev = builder.create<ConstantIntOp>(loc, true, 1);
@@ -279,6 +273,51 @@ void MLIRScanner::init(mlir::FuncOp function, const FunctionDecl *fd) {
279273
returnVal, std::vector<mlir::Value>({}));
280274
}
281275
}
276+
277+
if (auto D = dyn_cast<CXXMethodDecl>(fd)) {
278+
// ClangAST incorrectly does not contain the correct definition
279+
// of a union move operation and as such we _must_ emit a memcpy
280+
// for a defaulted union copy or move.
281+
if (D->getParent()->isUnion() && D->isDefaulted()) {
282+
mlir::Value V = ThisVal.val;
283+
assert(V);
284+
if (auto MT = V.getType().dyn_cast<MemRefType>()) {
285+
V = builder.create<polygeist::Pointer2MemrefOp>(
286+
loc, LLVM::LLVMPointerType::get(MT.getElementType()), V);
287+
}
288+
mlir::Value src = function.getArgument(1);
289+
if (auto MT = src.getType().dyn_cast<MemRefType>()) {
290+
src = builder.create<polygeist::Pointer2MemrefOp>(
291+
loc, LLVM::LLVMPointerType::get(MT.getElementType()), src);
292+
}
293+
mlir::Value typeSize = builder.create<polygeist::TypeSizeOp>(
294+
loc, builder.getIndexType(),
295+
mlir::TypeAttr::get(
296+
V.getType().cast<LLVM::LLVMPointerType>().getElementType()));
297+
typeSize = builder.create<arith::IndexCastOp>(loc, builder.getI64Type(),
298+
typeSize);
299+
V = builder.create<LLVM::BitcastOp>(
300+
loc,
301+
LLVM::LLVMPointerType::get(
302+
builder.getI8Type(),
303+
V.getType().cast<LLVM::LLVMPointerType>().getAddressSpace()),
304+
V);
305+
src = builder.create<LLVM::BitcastOp>(
306+
loc,
307+
LLVM::LLVMPointerType::get(
308+
builder.getI8Type(),
309+
src.getType().cast<LLVM::LLVMPointerType>().getAddressSpace()),
310+
src);
311+
mlir::Value volatileCpy = builder.create<ConstantIntOp>(loc, false, 1);
312+
builder.create<LLVM::MemcpyOp>(loc, V, src, typeSize, volatileCpy);
313+
}
314+
}
315+
316+
Stmt *stmt = fd->getBody();
317+
assert(stmt);
318+
if (ShowAST) {
319+
stmt->dump();
320+
}
282321
Visit(stmt);
283322

284323
if (function.getFunctionType().getResults().size()) {
@@ -3206,13 +3245,19 @@ mlir::Value MLIRScanner::GetAddressOfBaseClass(
32063245
Glob.CGM.getContext().getLValueReferenceType(QualType(BaseType, 0)));
32073246

32083247
size_t fnum;
3248+
bool subIndex = true;
32093249

32103250
if (isLLVMStructABI(RD, /*ST*/ nullptr)) {
32113251
auto &layout = Glob.CGM.getTypes().getCGRecordLayout(RD);
32123252
if (std::get<1>(tup))
32133253
fnum = layout.getVirtualBaseIndex(BaseDecl);
3214-
else
3215-
fnum = layout.getNonVirtualBaseLLVMFieldNo(BaseDecl);
3254+
else {
3255+
if (!layout.hasNonVirtualBaseLLVMField(BaseDecl)) {
3256+
subIndex = false;
3257+
} else {
3258+
fnum = layout.getNonVirtualBaseLLVMFieldNo(BaseDecl);
3259+
}
3260+
}
32163261
} else {
32173262
assert(!std::get<1>(tup) && "Should not see virtual bases here!");
32183263
fnum = 0;
@@ -3228,31 +3273,34 @@ mlir::Value MLIRScanner::GetAddressOfBaseClass(
32283273
assert(found);
32293274
}
32303275

3231-
if (auto mt = value.getType().dyn_cast<MemRefType>()) {
3232-
auto shape = std::vector<int64_t>(mt.getShape());
3233-
shape.erase(shape.begin());
3234-
auto mt0 = mlir::MemRefType::get(shape, mt.getElementType(),
3235-
MemRefLayoutAttrInterface(),
3236-
mt.getMemorySpace());
3237-
value = builder.create<polygeist::SubIndexOp>(loc, mt0, value,
3238-
getConstantIndex(fnum));
3239-
} else {
3240-
mlir::Value idx[] = {builder.create<arith::ConstantIntOp>(loc, 0, 32),
3241-
builder.create<arith::ConstantIntOp>(loc, fnum, 32)};
3242-
auto PT = value.getType().cast<LLVM::LLVMPointerType>();
3243-
mlir::Type ET;
3244-
if (auto ST =
3245-
PT.getElementType().dyn_cast<mlir::LLVM::LLVMStructType>()) {
3246-
ET = ST.getBody()[fnum];
3276+
if (subIndex) {
3277+
if (auto mt = value.getType().dyn_cast<MemRefType>()) {
3278+
auto shape = std::vector<int64_t>(mt.getShape());
3279+
shape.erase(shape.begin());
3280+
auto mt0 = mlir::MemRefType::get(shape, mt.getElementType(),
3281+
MemRefLayoutAttrInterface(),
3282+
mt.getMemorySpace());
3283+
value = builder.create<polygeist::SubIndexOp>(loc, mt0, value,
3284+
getConstantIndex(fnum));
32473285
} else {
3248-
ET = PT.getElementType()
3249-
.cast<mlir::LLVM::LLVMArrayType>()
3250-
.getElementType();
3251-
}
3286+
mlir::Value idx[] = {
3287+
builder.create<arith::ConstantIntOp>(loc, 0, 32),
3288+
builder.create<arith::ConstantIntOp>(loc, fnum, 32)};
3289+
auto PT = value.getType().cast<LLVM::LLVMPointerType>();
3290+
mlir::Type ET;
3291+
if (auto ST =
3292+
PT.getElementType().dyn_cast<mlir::LLVM::LLVMStructType>()) {
3293+
ET = ST.getBody()[fnum];
3294+
} else {
3295+
ET = PT.getElementType()
3296+
.cast<mlir::LLVM::LLVMArrayType>()
3297+
.getElementType();
3298+
}
32523299

3253-
value = builder.create<LLVM::GEPOp>(
3254-
loc, LLVM::LLVMPointerType::get(ET, PT.getAddressSpace()), value,
3255-
idx);
3300+
value = builder.create<LLVM::GEPOp>(
3301+
loc, LLVM::LLVMPointerType::get(ET, PT.getAddressSpace()), value,
3302+
idx);
3303+
}
32563304
}
32573305

32583306
auto pt = nt.dyn_cast<mlir::LLVM::LLVMPointerType>();

tools/mlir-clang/Lib/utils.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,5 +54,5 @@ Operation *mlirclang::replaceFuncByOperation(
5454
// NOTE: The attributes of the provided FuncOp is ignored.
5555
OperationState opState(b.getUnknownLoc(), opName, input,
5656
f.getCallableResults(), {});
57-
return b.createOperation(opState);
57+
return b.create(opState);
5858
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// RUN: mlir-clang %s --function=* --struct-abi=0 -memref-abi=0 -S | FileCheck %s
2+
3+
void run0(void*);
4+
void run1(void*);
5+
void run2(void*);
6+
7+
class M {
8+
public:
9+
M() { run0(this); }
10+
};
11+
12+
13+
struct _Alloc_hider : M
14+
{
15+
_Alloc_hider() { run1(this); }
16+
17+
};
18+
19+
class basic_ostringstream
20+
{
21+
public:
22+
_Alloc_hider _M_stringbuf;
23+
basic_ostringstream() { run2(this); }
24+
};
25+
26+
void a() {
27+
::basic_ostringstream a;
28+
}
29+
30+
// CHECK: func @_Z1av() attributes {llvm.linkage = #llvm.linkage<external>} {
31+
// CHECK-NEXT: %c1_i64 = arith.constant 1 : i64
32+
// CHECK-NEXT: %0 = llvm.alloca %c1_i64 x !llvm.struct<(struct<(i8)>)> : (i64) -> !llvm.ptr<struct<(struct<(i8)>)>>
33+
// CHECK-NEXT: call @_ZN19basic_ostringstreamC1Ev(%0) : (!llvm.ptr<struct<(struct<(i8)>)>>) -> ()
34+
// CHECK-NEXT: return
35+
// CHECK-NEXT: }
36+
// CHECK: func @_ZN19basic_ostringstreamC1Ev(%arg0: !llvm.ptr<struct<(struct<(i8)>)>>) attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
37+
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
38+
// CHECK-NEXT: %0 = llvm.getelementptr %arg0[%c0_i32, 0] : (!llvm.ptr<struct<(struct<(i8)>)>>, i32) -> !llvm.ptr<struct<(i8)>>
39+
// CHECK-NEXT: call @_ZN12_Alloc_hiderC1Ev(%0) : (!llvm.ptr<struct<(i8)>>) -> ()
40+
// CHECK-NEXT: %1 = llvm.bitcast %arg0 : !llvm.ptr<struct<(struct<(i8)>)>> to !llvm.ptr<i8>
41+
// CHECK-NEXT: call @_Z4run2Pv(%1) : (!llvm.ptr<i8>) -> ()
42+
// CHECK-NEXT: return
43+
// CHECK-NEXT: }
44+
// CHECK: func @_ZN12_Alloc_hiderC1Ev(%arg0: !llvm.ptr<struct<(i8)>>) attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
45+
// CHECK-NEXT: call @_ZN1MC1Ev(%arg0) : (!llvm.ptr<struct<(i8)>>) -> ()
46+
// CHECK-NEXT: %0 = llvm.bitcast %arg0 : !llvm.ptr<struct<(i8)>> to !llvm.ptr<i8>
47+
// CHECK-NEXT: call @_Z4run1Pv(%0) : (!llvm.ptr<i8>) -> ()
48+
// CHECK-NEXT: return
49+
// CHECK-NEXT: }
50+
// CHECK: func private @_Z4run2Pv(!llvm.ptr<i8>) attributes {llvm.linkage = #llvm.linkage<external>}
51+
// CHECK-NEXT: func @_ZN1MC1Ev(%arg0: !llvm.ptr<struct<(i8)>>) attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
52+
// CHECK-NEXT: %0 = llvm.bitcast %arg0 : !llvm.ptr<struct<(i8)>> to !llvm.ptr<i8>
53+
// CHECK-NEXT: call @_Z4run0Pv(%0) : (!llvm.ptr<i8>) -> ()
54+
// CHECK-NEXT: return
55+
// CHECK-NEXT: }
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// RUN: mlir-clang %s --function=* -S | FileCheck %s
2+
3+
union S {
4+
double d;
5+
};
6+
7+
class MyScalar {
8+
public:
9+
S v;
10+
MyScalar(double vv) {
11+
v.d = vv;
12+
}
13+
};
14+
15+
void use(double);
16+
void meta() {
17+
MyScalar alpha_scalar(1.0);
18+
alpha_scalar = MyScalar(3.0);
19+
use(alpha_scalar.v.d);
20+
}
21+
22+
// CHECK: func @_Z4metav() attributes {llvm.linkage = #llvm.linkage<external>} {
23+
// CHECK-DAG: %c1_i64 = arith.constant 1 : i64
24+
// CHECK-DAG: %cst = arith.constant 1.000000e+00 : f64
25+
// CHECK-DAG: %cst_0 = arith.constant 3.000000e+00 : f64
26+
// CHECK-DAG: %c0_i32 = arith.constant 0 : i32
27+
// CHECK-NEXT: %0 = llvm.alloca %c1_i64 x !llvm.struct<(struct<(f64)>)> : (i64) -> !llvm.ptr<struct<(struct<(f64)>)>>
28+
// CHECK-NEXT: %1 = llvm.alloca %c1_i64 x !llvm.struct<(struct<(f64)>)> : (i64) -> !llvm.ptr<struct<(struct<(f64)>)>>
29+
// CHECK-NEXT: %2 = llvm.alloca %c1_i64 x !llvm.struct<(struct<(f64)>)> : (i64) -> !llvm.ptr<struct<(struct<(f64)>)>>
30+
// CHECK-NEXT: call @_ZN8MyScalarC1Ed(%2, %cst) : (!llvm.ptr<struct<(struct<(f64)>)>>, f64) -> ()
31+
// CHECK-NEXT: call @_ZN8MyScalarC1Ed(%1, %cst_0) : (!llvm.ptr<struct<(struct<(f64)>)>>, f64) -> ()
32+
// CHECK-NEXT: %3 = llvm.load %1 : !llvm.ptr<struct<(struct<(f64)>)>>
33+
// CHECK-NEXT: llvm.store %3, %0 : !llvm.ptr<struct<(struct<(f64)>)>>
34+
// CHECK-NEXT: %4 = call @_ZN8MyScalaraSEOS_(%2, %0) : (!llvm.ptr<struct<(struct<(f64)>)>>, !llvm.ptr<struct<(struct<(f64)>)>>) -> !llvm.ptr<struct<(struct<(f64)>)>>
35+
// CHECK-NEXT: %5 = llvm.getelementptr %2[%c0_i32, 0] : (!llvm.ptr<struct<(struct<(f64)>)>>, i32) -> !llvm.ptr<struct<(f64)>>
36+
// CHECK-NEXT: %6 = llvm.getelementptr %5[%c0_i32, 0] : (!llvm.ptr<struct<(f64)>>, i32) -> !llvm.ptr<f64>
37+
// CHECK-NEXT: %7 = llvm.load %6 : !llvm.ptr<f64>
38+
// CHECK-NEXT: call @_Z3used(%7) : (f64) -> ()
39+
// CHECK-NEXT: return
40+
// CHECK-NEXT: }
41+
// CHECK: func @_ZN8MyScalarC1Ed(%arg0: !llvm.ptr<struct<(struct<(f64)>)>>, %arg1: f64) attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
42+
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
43+
// CHECK-NEXT: %0 = llvm.getelementptr %arg0[%c0_i32, 0] : (!llvm.ptr<struct<(struct<(f64)>)>>, i32) -> !llvm.ptr<struct<(f64)>>
44+
// CHECK-NEXT: %1 = llvm.getelementptr %0[%c0_i32, 0] : (!llvm.ptr<struct<(f64)>>, i32) -> !llvm.ptr<f64>
45+
// CHECK-NEXT: llvm.store %arg1, %1 : !llvm.ptr<f64>
46+
// CHECK-NEXT: return
47+
// CHECK-NEXT: }
48+
// CHECK: func @_ZN8MyScalaraSEOS_(%arg0: !llvm.ptr<struct<(struct<(f64)>)>>, %arg1: !llvm.ptr<struct<(struct<(f64)>)>>) -> !llvm.ptr<struct<(struct<(f64)>)>> attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
49+
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
50+
// CHECK-NEXT: %0 = llvm.getelementptr %arg0[%c0_i32, 0] : (!llvm.ptr<struct<(struct<(f64)>)>>, i32) -> !llvm.ptr<struct<(f64)>>
51+
// CHECK-NEXT: %1 = llvm.getelementptr %arg1[%c0_i32, 0] : (!llvm.ptr<struct<(struct<(f64)>)>>, i32) -> !llvm.ptr<struct<(f64)>>
52+
// CHECK-NEXT: %2 = call @_ZN1SaSEOS_(%0, %1) : (!llvm.ptr<struct<(f64)>>, !llvm.ptr<struct<(f64)>>) -> !llvm.ptr<struct<(f64)>>
53+
// CHECK-NEXT: return %arg0 : !llvm.ptr<struct<(struct<(f64)>)>>
54+
// CHECK-NEXT: }
55+
// CHECK: func @_ZN1SaSEOS_(%arg0: !llvm.ptr<struct<(f64)>>, %arg1: !llvm.ptr<struct<(f64)>>) -> !llvm.ptr<struct<(f64)>> attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
56+
// CHECK-DAG: %c8_i64 = arith.constant 8 : i64
57+
// CHECK-DAG: %false = arith.constant false
58+
// CHECK-NEXT: %[[i0:.+]] = llvm.bitcast %arg0 : !llvm.ptr<struct<(f64)>> to !llvm.ptr<i8>
59+
// CHECK-NEXT: %[[i1:.+]] = llvm.bitcast %arg1 : !llvm.ptr<struct<(f64)>> to !llvm.ptr<i8>
60+
// CHECK-NEXT: "llvm.intr.memcpy"(%[[i0]], %[[i1]], %c8_i64, %false) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i64, i1) -> ()
61+
// CHECK-NEXT: return %arg0 : !llvm.ptr<struct<(f64)>>
62+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)