Skip to content

Commit 9608d79

Browse files
committed
Correctly emit subindex for pointer offset
1 parent 253ac70 commit 9608d79

File tree

2 files changed

+33
-11
lines changed

2 files changed

+33
-11
lines changed

tools/cgeist/Lib/clang-mlir.cc

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2896,6 +2896,18 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
28962896
}
28972897
}
28982898
case clang::BinaryOperator::Opcode::BO_Add: {
2899+
auto emitSubindex = [&](auto mr, auto ptradd) {
2900+
auto mt = mr.getType().template dyn_cast<mlir::MemRefType>();
2901+
auto shape = std::vector<int64_t>(mt.getShape());
2902+
shape[0] = -1;
2903+
auto mt0 = mlir::MemRefType::get(shape, mt.getElementType(),
2904+
MemRefLayoutAttrInterface(),
2905+
mt.getMemorySpace());
2906+
ptradd = castToIndex(loc, ptradd);
2907+
return ValueCategory(
2908+
builder.create<polygeist::SubIndexOp>(loc, mt0, mr, ptradd),
2909+
/*isReference*/ false);
2910+
};
28992911
if (auto cty = dyn_cast<clang::ComplexType>(BO->getType())) {
29002912
mlir::Value real =
29012913
builder.create<AddFOp>(loc, getComplexPart(loc, lhs.val, 0),
@@ -2911,17 +2923,10 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
29112923
if (lhs_v.getType().isa<mlir::FloatType>()) {
29122924
return ValueCategory(builder.create<AddFOp>(loc, lhs_v, rhs_v),
29132925
/*isReference*/ false);
2914-
} else if (auto mt = lhs_v.getType().dyn_cast<mlir::MemRefType>()) {
2915-
auto shape = std::vector<int64_t>(mt.getShape());
2916-
shape[0] = -1;
2917-
auto mt0 = mlir::MemRefType::get(shape, mt.getElementType(),
2918-
MemRefLayoutAttrInterface(),
2919-
mt.getMemorySpace());
2920-
auto ptradd = rhs_v;
2921-
ptradd = castToIndex(loc, ptradd);
2922-
return ValueCategory(
2923-
builder.create<polygeist::SubIndexOp>(loc, mt0, lhs_v, ptradd),
2924-
/*isReference*/ false);
2926+
} else if (lhs_v.getType().isa<mlir::MemRefType>()) {
2927+
return emitSubindex(lhs_v, rhs_v);
2928+
} else if (rhs_v.getType().isa<mlir::MemRefType>()) {
2929+
return emitSubindex(rhs_v, lhs_v);
29252930
} else if (auto pt =
29262931
lhs_v.getType().dyn_cast<mlir::LLVM::LLVMPointerType>()) {
29272932
return ValueCategory(

tools/cgeist/Test/Verification/memrefaddassign.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,26 @@ float *foo(float *a) {
44
a += 32;
55
return a;
66
}
7+
float *foo1(float *a) {
8+
return a + 32;
9+
}
10+
float *foo2(float *a) {
11+
return 32 + a;
12+
}
713
// CHECK: func @_Z3fooPf(%[[arg0:.+]]: memref<?xf32>)
814
// CHECK-NEXT %[[c32:.+]] = arith.constant 32 : index
915
// CHECK-NEXT %[[V0:.+]] = "polygeist.subindex"(%[[arg0]], %[[c32]]) : (memref<?xf32>, index) -> memref<?xf32>
1016
// CHECK-NEXT return %[[V0]] : memref<?xf32>
1117
// CHECK-NEXT }
1218

19+
// CHECK: func @_Z4foo1Pf(%[[arg0:.+]]: memref<?xf32>)
20+
// CHECK-NEXT %[[c32:.+]] = arith.constant 32 : index
21+
// CHECK-NEXT %[[V0:.+]] = "polygeist.subindex"(%[[arg0]], %[[c32]]) : (memref<?xf32>, index) -> memref<?xf32>
22+
// CHECK-NEXT return %[[V0]] : memref<?xf32>
23+
// CHECK-NEXT }
24+
25+
// CHECK: func @_Z4foo2Pf(%[[arg0:.+]]: memref<?xf32>)
26+
// CHECK-NEXT %[[c32:.+]] = arith.constant 32 : index
27+
// CHECK-NEXT %[[V0:.+]] = "polygeist.subindex"(%[[arg0]], %[[c32]]) : (memref<?xf32>, index) -> memref<?xf32>
28+
// CHECK-NEXT return %[[V0]] : memref<?xf32>
29+
// CHECK-NEXT }

0 commit comments

Comments
 (0)