@@ -1107,7 +1107,23 @@ ValueCategory MLIRScanner::VisitCXXNewExpr(clang::CXXNewExpr *expr) {
1107
1107
1108
1108
mlir::Value alloc;
1109
1109
mlir::Value arrayCons;
1110
- if (auto mt = ty.dyn_cast <mlir::MemRefType>()) {
1110
+ if (!expr->placement_arguments ().empty ()) {
1111
+ mlir::Value val = Visit (*expr->placement_arg_begin ()).getValue (builder);
1112
+ if (auto mt = ty.dyn_cast <mlir::MemRefType>()) {
1113
+ arrayCons = alloc =
1114
+ builder.create <polygeist::Pointer2MemrefOp>(loc, mt, val);
1115
+ } else {
1116
+ arrayCons = alloc = builder.create <mlir::LLVM::BitcastOp>(loc, ty, val);
1117
+ auto PT = ty.cast <LLVM::LLVMPointerType>();
1118
+ if (expr->isArray ())
1119
+ arrayCons = builder.create <mlir::LLVM::BitcastOp>(
1120
+ loc,
1121
+ LLVM::LLVMPointerType::get (
1122
+ LLVM::LLVMArrayType::get (PT.getElementType (), 0 ),
1123
+ PT.getAddressSpace ()),
1124
+ alloc);
1125
+ }
1126
+ } else if (auto mt = ty.dyn_cast <mlir::MemRefType>()) {
1111
1127
auto shape = std::vector<int64_t >(mt.getShape ());
1112
1128
mlir::Value args[1 ] = {count};
1113
1129
arrayCons = alloc = builder.create <mlir::memref::AllocOp>(loc, mt, args);
@@ -1116,7 +1132,7 @@ ValueCategory MLIRScanner::VisitCXXNewExpr(clang::CXXNewExpr *expr) {
1116
1132
auto typeSize = getTypeSize (expr->getAllocatedType ());
1117
1133
mlir::Value args[1 ] = {builder.create <arith::MulIOp>(loc, typeSize, count)};
1118
1134
args[0 ] = builder.create <IndexCastOp>(loc, i64 , args[0 ]);
1119
- alloc = builder.create <mlir::LLVM::BitcastOp>(
1135
+ arrayCons = alloc = builder.create <mlir::LLVM::BitcastOp>(
1120
1136
loc, ty,
1121
1137
builder
1122
1138
.create <mlir::LLVM::CallOp>(loc, Glob.GetOrCreateMallocFunction (),
0 commit comments