Skip to content

Commit 5d40003

Browse files
authored
Update VisitExtVectorElementExpr to handle memref-abi and llvm-abi (#269)
* Fix ext_vector_type test; the test did not correctly checked the input * update ext_vector_type test to check for memref-abi equal to 0 and 1 * Update VisitExtVectorElementExpr to handle memref-abi and llvm-abi * Update tests * apply clang-format * update if else format * replace vector's type size_t by float within the test Co-authored-by: Jefferson Le Quellec <[email protected]>
1 parent f6a9282 commit 5d40003

File tree

2 files changed

+81
-30
lines changed

2 files changed

+81
-30
lines changed

tools/cgeist/Lib/clang-mlir.cc

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -463,28 +463,56 @@ mlir::Value MLIRScanner::createAllocOp(mlir::Type t, VarDecl *name,
463463
ValueCategory
464464
MLIRScanner::VisitExtVectorElementExpr(clang::ExtVectorElementExpr *expr) {
465465
auto base = Visit(expr->getBase());
466+
466467
SmallVector<uint32_t, 4> indices;
467468
expr->getEncodedElementAccess(indices);
468469
assert(indices.size() == 1 &&
469470
"The support for higher dimensions to be implemented.");
470-
auto loc = getMLIRLocation(expr->getExprLoc());
471-
auto idx = castToIndex(getMLIRLocation(expr->getAccessorLoc()),
472-
builder.create<ConstantIntOp>(loc, indices[0], 32));
471+
473472
assert(base.isReference);
474473
base.isReference = false;
475-
auto mt = base.val.getType().cast<MemRefType>();
476-
auto shape = std::vector<int64_t>(mt.getShape());
477-
if (shape.size() == 1) {
478-
shape[0] = -1;
474+
475+
const auto et = base.val.getType();
476+
assert(et.isa<LLVM::LLVMPointerType>() || et.isa<MemRefType>());
477+
478+
ValueCategory result = nullptr;
479+
const auto exprLoc = getMLIRLocation(expr->getExprLoc());
480+
const auto accLoc = getMLIRLocation(expr->getAccessorLoc());
481+
const mlir::Value idxs[2] = {
482+
builder.create<ConstantIntOp>(exprLoc, 0, 32),
483+
builder.create<ConstantIntOp>(exprLoc, indices[0], 32),
484+
};
485+
486+
if (const auto pt = et.dyn_cast<LLVM::LLVMPointerType>()) {
487+
auto pt0 =
488+
pt.getElementType().cast<mlir::LLVM::LLVMArrayType>().getElementType();
489+
base.val = builder.create<mlir::LLVM::GEPOp>(
490+
exprLoc, mlir::LLVM::LLVMPointerType::get(pt0, pt.getAddressSpace()),
491+
base.val, idxs);
492+
493+
result = ValueCategory(base.val, true);
494+
} else if (const auto mt = et.dyn_cast<MemRefType>()) {
495+
auto shape = std::vector<int64_t>(mt.getShape());
496+
497+
if (shape.size() == 1) {
498+
shape[0] = -1;
499+
} else {
500+
shape.erase(shape.begin());
501+
}
502+
503+
auto mt0 =
504+
mlir::MemRefType::get(shape, mt.getElementType(),
505+
MemRefLayoutAttrInterface(), mt.getMemorySpace());
506+
base.val = builder.create<polygeist::SubIndexOp>(
507+
exprLoc, mt0, base.val, castToIndex(accLoc, idxs[0]));
508+
509+
result = CommonArrayLookup(exprLoc, base, castToIndex(accLoc, idxs[1]),
510+
base.isReference);
479511
} else {
480-
shape.erase(shape.begin());
512+
llvm_unreachable("Unexpected MLIR type received");
481513
}
482-
auto mt0 =
483-
mlir::MemRefType::get(shape, mt.getElementType(),
484-
MemRefLayoutAttrInterface(), mt.getMemorySpace());
485-
base.val = builder.create<polygeist::SubIndexOp>(loc, mt0, base.val,
486-
getConstantIndex(0));
487-
return CommonArrayLookup(loc, base, idx, base.isReference);
514+
515+
return result;
488516
}
489517

490518
ValueCategory MLIRScanner::VisitConstantExpr(clang::ConstantExpr *expr) {
@@ -5108,8 +5136,12 @@ mlir::Type MLIRASTConsumer::getMLIRType(clang::QualType qt, bool *implicitRef,
51085136
}
51095137
if (!memRefABI || !allowMerge ||
51105138
ET.isa<LLVM::LLVMPointerType, LLVM::LLVMArrayType,
5111-
LLVM::LLVMFunctionType, LLVM::LLVMStructType>())
5112-
return LLVM::LLVMFixedVectorType::get(ET, size);
5139+
LLVM::LLVMFunctionType, LLVM::LLVMStructType>()) {
5140+
if (mlir::LLVM::LLVMFixedVectorType::isValidElementType(ET)) {
5141+
return mlir::LLVM::LLVMFixedVectorType::get(ET, size);
5142+
}
5143+
return mlir::LLVM::LLVMArrayType::get(ET, size);
5144+
}
51135145
if (implicitRef)
51145146
*implicitRef = true;
51155147
return mlir::MemRefType::get({size}, ET);
Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,41 @@
1-
// RUN: cgeist %s --function=* -S | FileCheck %s
1+
// RUN: cgeist %s --function=* -memref-abi=1 -S | FileCheck %s
2+
// RUN: cgeist %s --function=* -memref-abi=0 -S | FileCheck %s -check-prefix=CHECK2
23

3-
typedef size_t size_t_vec __attribute__((ext_vector_type(3)));
4+
typedef float float_vec __attribute__((ext_vector_type(3)));
45

5-
size_t evt(size_t_vec stv) {
6+
float evt(float_vec stv) {
67
return stv.x;
78
}
89

9-
extern "C" const size_t_vec stv;
10-
size_t evt2() {
10+
extern "C" const float_vec stv;
11+
float evt2() {
1112
return stv.x;
1213
}
1314

14-
// CHECK: func.func @_Z3evtDv3_i(%arg0: memref<?x3xi32>) -> i32 attributes {llvm.linkage = #llvm.linkage<external>}
15-
// CHECK-NEXT: %0 = affine.load %arg0[0, 0] : memref<?x3xi32>
16-
// CHECK-NEXT: return %0 : i32
17-
// CHECK-NEXT: }
18-
// CHECK: func.func @_Z4evt2v() -> i32 attributes {llvm.linkage = #llvm.linkage<external>} {
19-
// CHECK-NEXT: %0 = memref.get_global @stv : memref<3xi32>
20-
// CHECK-NEXT: %1 = affine.load %0[0] : memref<3xi32>
21-
// CHECK-NEXT: return %1 : i32
22-
// CHECK-NEXT: }
15+
// CHECK: memref.global @stv : memref<3xf32>
16+
// CHECK: func.func @_Z3evtDv3_f(%arg0: memref<?x3xf32>) -> f32 attributes {llvm.linkage = #llvm.linkage<external>} {
17+
// CHECK-NEXT: %0 = affine.load %arg0[0, 0] : memref<?x3xf32>
18+
// CHECK-NEXT: return %0 : f32
19+
// CHECK-NEXT: }
20+
// CHECK: func.func @_Z4evt2v() -> f32 attributes {llvm.linkage = #llvm.linkage<external>} {
21+
// CHECK-NEXT: %0 = memref.get_global @stv : memref<3xf32>
22+
// CHECK-NEXT: %1 = affine.load %0[0] : memref<3xf32>
23+
// CHECK-NEXT: return %1 : f32
24+
// CHECK-NEXT: }
25+
26+
// CHECK2: llvm.mlir.global external @stv() {addr_space = 0 : i32} : !llvm.array<3 x f32>
27+
// CHECK2: func.func @_Z3evtDv3_f(%arg0: !llvm.array<3 x f32>) -> f32 attributes {llvm.linkage = #llvm.linkage<external>} {
28+
// CHECK2-NEXT: %c1_i64 = arith.constant 1 : i64
29+
// CHECK2-NEXT: %0 = llvm.alloca %c1_i64 x !llvm.array<3 x f32> : (i64) -> !llvm.ptr<array<3 x f32>>
30+
// CHECK2-NEXT: llvm.store %arg0, %0 : !llvm.ptr<array<3 x f32>>
31+
// CHECK2-NEXT: %1 = llvm.getelementptr %0[0, 0] : (!llvm.ptr<array<3 x f32>>) -> !llvm.ptr<f32>
32+
// CHECK2-NEXT: %2 = llvm.load %1 : !llvm.ptr<f32>
33+
// CHECK2-NEXT: return %2 : f32
34+
// CHECK2-NEXT: }
35+
// CHECK2: func.func @_Z4evt2v() -> f32 attributes {llvm.linkage = #llvm.linkage<external>} {
36+
// CHECK2-NEXT: %0 = llvm.mlir.addressof @stv : !llvm.ptr<array<3 x f32>>
37+
// CHECK2-NEXT: %1 = llvm.getelementptr %0[0, 0] : (!llvm.ptr<array<3 x f32>>) -> !llvm.ptr<f32>
38+
// CHECK2-NEXT: %2 = llvm.load %1 : !llvm.ptr<f32>
39+
// CHECK2-NEXT: return %2 : f32
40+
// CHECK2-NEXT: }
41+

0 commit comments

Comments
 (0)