27
27
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
28
28
#include " mlir/Dialect/Math/IR/Math.h"
29
29
#include " mlir/Dialect/MemRef/IR/MemRef.h"
30
+ #include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
30
31
#include " mlir/Dialect/NamedTuple/IR/NamedTuple.h"
32
+ #include " mlir/Dialect/NamedTuple/IR/NamedTupleDialect.h"
33
+ #include " mlir/Dialect/NamedTuple/IR/NamedTupleTypes.h"
31
34
#include " mlir/Dialect/SCF/IR/SCF.h"
32
35
#include " mlir/Dialect/Vector/IR/VectorOps.h"
36
+ #include " mlir/IR/Attributes.h"
37
+ #include " mlir/IR/BuiltinAttributes.h"
33
38
#include " mlir/IR/BuiltinDialect.h"
39
+ #include " mlir/IR/BuiltinOps.h"
34
40
#include " mlir/IR/BuiltinTypes.h"
35
41
#include " mlir/IR/Operation.h"
36
42
#include " mlir/IR/Region.h"
45
51
#include " mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
46
52
#include " mlir/Target/LLVMIR/Export.h"
47
53
#include " mlir/Transforms/DialectConversion.h"
54
+ #include " clang/CIR/Dialect/IR/CIRDataLayout.h"
48
55
#include " clang/CIR/Dialect/IR/CIRDialect.h"
49
56
#include " clang/CIR/Dialect/IR/CIRTypes.h"
50
57
#include " clang/CIR/LowerToLLVM.h"
51
58
#include " clang/CIR/LowerToMLIR.h"
52
59
#include " clang/CIR/LoweringHelpers.h"
53
60
#include " clang/CIR/Passes.h"
61
+ #include " llvm/ADT/ArrayRef.h"
54
62
#include " llvm/ADT/STLExtras.h"
55
63
#include " llvm/ADT/SmallVector.h"
56
64
#include " llvm/ADT/TypeSwitch.h"
@@ -172,7 +180,7 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
172
180
mlir::Type mlirType =
173
181
convertTypeForMemory (*getTypeConverter (), adaptor.getAllocaType ());
174
182
175
- // FIXME: Some types can not be converted yet (e.g. struct)
183
+ // FIXME: Some types can not be converted yet
176
184
if (!mlirType)
177
185
return mlir::LogicalResult::failure ();
178
186
@@ -281,6 +289,71 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
281
289
}
282
290
};
283
291
292
+ // Lower cir.get_member
293
+ //
294
+ // clang-format off
295
+ //
296
+ // %5 = cir.get_member %1[1] {name = "b"} : !cir.ptr<!named_tuple.named_tuple<"s", [i32, f64, i8]>> -> !cir.ptr<!cir.double>
297
+ //
298
+ // to something like
299
+ //
300
+ // %1 = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
301
+ // %c8 = arith.constant 8 : index
302
+ // %view_1 = memref.view %1[%c8][] : memref<24xi8> to memref<f64>
303
+ // clang-format on
304
+ class CIRGetMemberOpLowering
305
+ : public mlir::OpConversionPattern<cir::GetMemberOp> {
306
+ cir::CIRDataLayout const &dataLayout;
307
+
308
+ public:
309
+ using mlir::OpConversionPattern<cir::GetMemberOp>::OpConversionPattern;
310
+
311
+ CIRGetMemberOpLowering (const mlir::TypeConverter &typeConverter,
312
+ mlir::MLIRContext *context,
313
+ cir::CIRDataLayout const &dataLayout)
314
+ : OpConversionPattern{typeConverter, context}, dataLayout{dataLayout} {}
315
+
316
+ mlir::LogicalResult
317
+ matchAndRewrite (cir::GetMemberOp op, OpAdaptor adaptor,
318
+ mlir::ConversionPatternRewriter &rewriter) const override {
319
+ auto pointeeType = op.getAddrTy ().getPointee ();
320
+ if (!mlir::isa<cir::StructType>(pointeeType))
321
+ op.emitError (" GetMemberOp only works on pointer to cir::StructType" );
322
+ auto structType = mlir::cast<cir::StructType>(pointeeType);
323
+ // For now, just rely on the datalayout of the high-level type since the
324
+ // datalayout of low-level type is not implemented yet. But since C++ is a
325
+ // concrete datalayout, both datalayouts are the same.
326
+ auto *structLayout = dataLayout.getStructLayout (structType);
327
+
328
+ // Get the lowered type: memref<!named_tuple.named_tuple<>>
329
+ auto memref = mlir::cast<mlir::MemRefType>(adaptor.getAddr ().getType ());
330
+ // Alias the memref of struct to a memref of an i8 array of the same size.
331
+ const std::array linearizedSize{
332
+ static_cast <std::int64_t >(dataLayout.getTypeStoreSize (structType))};
333
+ auto flattenMemRef = mlir::MemRefType::get (
334
+ linearizedSize, mlir::IntegerType::get (memref.getContext (), 8 ));
335
+ // Use a special cast because normal memref cast cannot do such an extreme
336
+ // cast.
337
+ auto bytesMemRef = rewriter.create <mlir::named_tuple::CastOp>(
338
+ op.getLoc (), mlir::TypeRange{flattenMemRef},
339
+ mlir::ValueRange{adaptor.getAddr ()});
340
+
341
+ auto memberIndex = op.getIndex ();
342
+ auto namedTupleType =
343
+ mlir::cast<mlir::named_tuple::NamedTupleType>(memref.getElementType ());
344
+ // The lowered type of the element to access in the named_tuple.
345
+ auto loweredMemberType = namedTupleType.getType (memberIndex);
346
+ auto elementMemRefTy = mlir::MemRefType::get ({}, loweredMemberType);
347
+ auto offset = structLayout->getElementOffset (memberIndex);
348
+ // Synthesize the byte access to right lowered type.
349
+ auto byteShift =
350
+ rewriter.create <mlir::arith::ConstantIndexOp>(op.getLoc (), offset);
351
+ rewriter.replaceOpWithNewOp <mlir::memref::ViewOp>(
352
+ op, elementMemRefTy, bytesMemRef, byteShift, mlir::ValueRange{});
353
+ return mlir::LogicalResult::success ();
354
+ }
355
+ };
356
+
284
357
// / Converts CIR unary math ops (e.g., cir::SinOp) to their MLIR equivalents
285
358
// / (e.g., math::SinOp) using a generic template to avoid redundant boilerplate
286
359
// / matchAndRewrite definitions.
@@ -1277,7 +1350,8 @@ class CIRTrapOpLowering : public mlir::OpConversionPattern<cir::TrapOp> {
1277
1350
};
1278
1351
1279
1352
void populateCIRToMLIRConversionPatterns (mlir::RewritePatternSet &patterns,
1280
- mlir::TypeConverter &converter) {
1353
+ mlir::TypeConverter &converter,
1354
+ cir::CIRDataLayout &cirDataLayout) {
1281
1355
patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
1282
1356
1283
1357
patterns
@@ -1298,6 +1372,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
1298
1372
CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering,
1299
1373
CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering,
1300
1374
CIRTrapOpLowering>(converter, patterns.getContext ());
1375
+ patterns.add <CIRGetMemberOpLowering>(converter, patterns.getContext (),
1376
+ cirDataLayout);
1301
1377
}
1302
1378
1303
1379
mlir::TypeConverter prepareTypeConverter (mlir::DataLayout &dataLayout) {
@@ -1354,7 +1430,7 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
1354
1430
curType = arrayType.getEltType ();
1355
1431
}
1356
1432
auto elementType = converter.convertType (curType);
1357
- // FIXME: The element type might not be converted (e.g. struct)
1433
+ // FIXME: The element type might not be converted
1358
1434
if (!elementType)
1359
1435
return nullptr ;
1360
1436
return mlir::MemRefType::get (shape, elementType);
@@ -1396,19 +1472,21 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
1396
1472
void ConvertCIRToMLIRPass::runOnOperation () {
1397
1473
auto module = getOperation ();
1398
1474
mlir::DataLayout dataLayout{module };
1475
+ cir::CIRDataLayout cirDataLayout{module };
1399
1476
auto converter = prepareTypeConverter (dataLayout);
1400
1477
1401
1478
mlir::RewritePatternSet patterns (&getContext ());
1402
1479
1403
1480
populateCIRLoopToSCFConversionPatterns (patterns, converter);
1404
- populateCIRToMLIRConversionPatterns (patterns, converter);
1481
+ populateCIRToMLIRConversionPatterns (patterns, converter, cirDataLayout );
1405
1482
1406
1483
mlir::ConversionTarget target (getContext ());
1407
1484
target.addLegalOp <mlir::ModuleOp>();
1408
1485
target.addLegalDialect <mlir::affine::AffineDialect, mlir::arith::ArithDialect,
1409
1486
mlir::memref::MemRefDialect, mlir::func::FuncDialect,
1410
1487
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
1411
1488
mlir::math::MathDialect, mlir::vector::VectorDialect,
1489
+ mlir::named_tuple::NamedTupleDialect,
1412
1490
mlir::LLVM::LLVMDialect>();
1413
1491
target.addIllegalDialect <cir::CIRDialect>();
1414
1492
0 commit comments