@@ -463,28 +463,56 @@ mlir::Value MLIRScanner::createAllocOp(mlir::Type t, VarDecl *name,
463
463
ValueCategory
464
464
MLIRScanner::VisitExtVectorElementExpr (clang::ExtVectorElementExpr *expr) {
465
465
auto base = Visit (expr->getBase ());
466
+
466
467
SmallVector<uint32_t , 4 > indices;
467
468
expr->getEncodedElementAccess (indices);
468
469
assert (indices.size () == 1 &&
469
470
" The support for higher dimensions to be implemented." );
470
- auto loc = getMLIRLocation (expr->getExprLoc ());
471
- auto idx = castToIndex (getMLIRLocation (expr->getAccessorLoc ()),
472
- builder.create <ConstantIntOp>(loc, indices[0 ], 32 ));
471
+
473
472
assert (base.isReference );
474
473
base.isReference = false ;
475
- auto mt = base.val .getType ().cast <MemRefType>();
476
- auto shape = std::vector<int64_t >(mt.getShape ());
477
- if (shape.size () == 1 ) {
478
- shape[0 ] = -1 ;
474
+
475
+ const auto et = base.val .getType ();
476
+ assert (et.isa <LLVM::LLVMPointerType>() || et.isa <MemRefType>());
477
+
478
+ ValueCategory result = nullptr ;
479
+ const auto exprLoc = getMLIRLocation (expr->getExprLoc ());
480
+ const auto accLoc = getMLIRLocation (expr->getAccessorLoc ());
481
+ const mlir::Value idxs[2 ] = {
482
+ builder.create <ConstantIntOp>(exprLoc, 0 , 32 ),
483
+ builder.create <ConstantIntOp>(exprLoc, indices[0 ], 32 ),
484
+ };
485
+
486
+ if (const auto pt = et.dyn_cast <LLVM::LLVMPointerType>()) {
487
+ auto pt0 =
488
+ pt.getElementType ().cast <mlir::LLVM::LLVMArrayType>().getElementType ();
489
+ base.val = builder.create <mlir::LLVM::GEPOp>(
490
+ exprLoc, mlir::LLVM::LLVMPointerType::get (pt0, pt.getAddressSpace ()),
491
+ base.val , idxs);
492
+
493
+ result = ValueCategory (base.val , true );
494
+ } else if (const auto mt = et.dyn_cast <MemRefType>()) {
495
+ auto shape = std::vector<int64_t >(mt.getShape ());
496
+
497
+ if (shape.size () == 1 ) {
498
+ shape[0 ] = -1 ;
499
+ } else {
500
+ shape.erase (shape.begin ());
501
+ }
502
+
503
+ auto mt0 =
504
+ mlir::MemRefType::get (shape, mt.getElementType (),
505
+ MemRefLayoutAttrInterface (), mt.getMemorySpace ());
506
+ base.val = builder.create <polygeist::SubIndexOp>(
507
+ exprLoc, mt0, base.val , castToIndex (accLoc, idxs[0 ]));
508
+
509
+ result = CommonArrayLookup (exprLoc, base, castToIndex (accLoc, idxs[1 ]),
510
+ base.isReference );
479
511
} else {
480
- shape. erase (shape. begin () );
512
+ llvm_unreachable ( " Unexpected MLIR type received " );
481
513
}
482
- auto mt0 =
483
- mlir::MemRefType::get (shape, mt.getElementType (),
484
- MemRefLayoutAttrInterface (), mt.getMemorySpace ());
485
- base.val = builder.create <polygeist::SubIndexOp>(loc, mt0, base.val ,
486
- getConstantIndex (0 ));
487
- return CommonArrayLookup (loc, base, idx, base.isReference );
514
+
515
+ return result;
488
516
}
489
517
490
518
ValueCategory MLIRScanner::VisitConstantExpr (clang::ConstantExpr *expr) {
@@ -5108,8 +5136,12 @@ mlir::Type MLIRASTConsumer::getMLIRType(clang::QualType qt, bool *implicitRef,
5108
5136
}
5109
5137
if (!memRefABI || !allowMerge ||
5110
5138
ET.isa <LLVM::LLVMPointerType, LLVM::LLVMArrayType,
5111
- LLVM::LLVMFunctionType, LLVM::LLVMStructType>())
5112
- return LLVM::LLVMFixedVectorType::get (ET, size);
5139
+ LLVM::LLVMFunctionType, LLVM::LLVMStructType>()) {
5140
+ if (mlir::LLVM::LLVMFixedVectorType::isValidElementType (ET)) {
5141
+ return mlir::LLVM::LLVMFixedVectorType::get (ET, size);
5142
+ }
5143
+ return mlir::LLVM::LLVMArrayType::get (ET, size);
5144
+ }
5113
5145
if (implicitRef)
5114
5146
*implicitRef = true ;
5115
5147
return mlir::MemRefType::get ({size}, ET);
0 commit comments