Skip to content

Commit f4c806d

Browse files
authored
Correct struct pointer offset calc (#279)
* Correct struct pointer offset calc * fix test * fix test
1 parent 11ed9b3 commit f4c806d

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

tools/cgeist/Lib/clang-mlir.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2740,12 +2740,20 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
27402740
auto lhs_v = lhs.getValue(loc, builder);
27412741
auto rhs_v = rhs.getValue(loc, builder);
27422742
if (auto mt = lhs_v.getType().dyn_cast<mlir::MemRefType>()) {
2743+
mlir::Type innerType = mt.getElementType();
2744+
auto shape = mt.getShape();
2745+
for (size_t i = 1; i < shape.size(); i++)
2746+
innerType = LLVM::LLVMArrayType::get(innerType, shape[i]);
27432747
lhs_v = builder.create<polygeist::Memref2PointerOp>(
2744-
loc, LLVM::LLVMPointerType::get(mt.getElementType()), lhs_v);
2748+
loc, LLVM::LLVMPointerType::get(innerType), lhs_v);
27452749
}
27462750
if (auto mt = rhs_v.getType().dyn_cast<mlir::MemRefType>()) {
2751+
mlir::Type innerType = mt.getElementType();
2752+
auto shape = mt.getShape();
2753+
for (size_t i = 1; i < shape.size(); i++)
2754+
innerType = LLVM::LLVMArrayType::get(innerType, shape[i]);
27472755
rhs_v = builder.create<polygeist::Memref2PointerOp>(
2748-
loc, LLVM::LLVMPointerType::get(mt.getElementType()), rhs_v);
2756+
loc, LLVM::LLVMPointerType::get(innerType), rhs_v);
27492757
}
27502758
if (lhs_v.getType().isa<mlir::FloatType>()) {
27512759
assert(rhs_v.getType() == lhs_v.getType());
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// RUN: cgeist %s --function=* -S | FileCheck %s
2+
3+
struct latLong
4+
{
5+
int lat;
6+
int lng;
7+
};
8+
9+
int foo(struct latLong *a, struct latLong *b) {
10+
return a - b;
11+
}
12+
struct latLong *bar(struct latLong *a, int b) {
13+
return a - b;
14+
}
15+
16+
// CHECK: func.func @foo(%arg0: memref<?x2xi32>, %arg1: memref<?x2xi32>) -> i32
17+
// CHECK-NEXT: %c8_i64 = arith.constant 8 : i64
18+
// CHECK-DAG: %[[i0:.*]] = "polygeist.memref2pointer"(%arg0) : (memref<?x2xi32>) -> !llvm.ptr<array<2 x i32>>
19+
// CHECK-DAG: %[[i1:.*]] = "polygeist.memref2pointer"(%arg1) : (memref<?x2xi32>) -> !llvm.ptr<array<2 x i32>>
20+
// CHECK-DAG: %[[i2:.*]] = llvm.ptrtoint %[[i0]] : !llvm.ptr<array<2 x i32>> to i64
21+
// CHECK-DAG: %[[i3:.*]] = llvm.ptrtoint %[[i1]] : !llvm.ptr<array<2 x i32>> to i64
22+
// CHECK-NEXT: %4 = arith.subi %[[i2]], %[[i3]] : i64
23+
// CHECK-NEXT: %5 = arith.divsi %4, %c8_i64 : i64
24+
// CHECK-NEXT: %6 = arith.trunci %5 : i64 to i32
25+
// CHECK-NEXT: return %6 : i32
26+
// CHECK-NEXT: }
27+
// CHECK: func.func @bar(%arg0: memref<?x2xi32>, %arg1: i32) -> memref<?x2xi32>
28+
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
29+
// CHECK-NEXT: %0 = "polygeist.memref2pointer"(%arg0) : (memref<?x2xi32>) -> !llvm.ptr<array<2 x i32>>
30+
// CHECK-NEXT: %1 = arith.subi %c0_i32, %arg1 : i32
31+
// CHECK-NEXT: %2 = llvm.getelementptr %0[%1] : (!llvm.ptr<array<2 x i32>>, i32) -> !llvm.ptr<array<2 x i32>>
32+
// CHECK-NEXT: %3 = "polygeist.pointer2memref"(%2) : (!llvm.ptr<array<2 x i32>>) -> memref<?x2xi32>
33+
// CHECK-NEXT: return %3 : memref<?x2xi32>
34+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)