29
29
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
30
30
#include " mlir/Dialect/Math/IR/Math.h"
31
31
#include " mlir/Dialect/MemRef/IR/MemRef.h"
32
+ #include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
32
33
#include " mlir/Dialect/NamedTuple/IR/NamedTuple.h"
34
+ #include " mlir/Dialect/NamedTuple/IR/NamedTupleDialect.h"
35
+ #include " mlir/Dialect/NamedTuple/IR/NamedTupleTypes.h"
33
36
#include " mlir/Dialect/SCF/IR/SCF.h"
34
37
#include " mlir/Dialect/SCF/Transforms/Passes.h"
35
38
#include " mlir/Dialect/Vector/IR/VectorOps.h"
39
+ #include " mlir/IR/Attributes.h"
40
+ #include " mlir/IR/BuiltinAttributes.h"
36
41
#include " mlir/IR/BuiltinDialect.h"
42
+ #include " mlir/IR/BuiltinOps.h"
37
43
#include " mlir/IR/BuiltinTypes.h"
38
44
#include " mlir/IR/Operation.h"
39
45
#include " mlir/IR/Region.h"
48
54
#include " mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
49
55
#include " mlir/Target/LLVMIR/Export.h"
50
56
#include " mlir/Transforms/DialectConversion.h"
57
+ #include " clang/CIR/Dialect/IR/CIRDataLayout.h"
51
58
#include " clang/CIR/Dialect/IR/CIRDialect.h"
52
59
#include " clang/CIR/Dialect/IR/CIRTypes.h"
53
60
#include " clang/CIR/LowerToLLVM.h"
54
61
#include " clang/CIR/LowerToMLIR.h"
55
62
#include " clang/CIR/LoweringHelpers.h"
56
63
#include " clang/CIR/Passes.h"
64
+ #include " llvm/ADT/ArrayRef.h"
57
65
#include " llvm/ADT/STLExtras.h"
58
66
#include " llvm/ADT/Sequence.h"
59
67
#include " llvm/ADT/SmallVector.h"
@@ -176,7 +184,7 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
176
184
mlir::Type mlirType =
177
185
convertTypeForMemory (*getTypeConverter (), adaptor.getAllocaType ());
178
186
179
- // FIXME: Some types can not be converted yet (e.g. struct)
187
+ // FIXME: Some types can not be converted yet
180
188
if (!mlirType)
181
189
return mlir::LogicalResult::failure ();
182
190
@@ -285,6 +293,71 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
285
293
}
286
294
};
287
295
296
+ // Lower cir.get_member
297
+ //
298
+ // clang-format off
299
+ //
300
+ // %5 = cir.get_member %1[1] {name = "b"} : !cir.ptr<!named_tuple.named_tuple<"s", [i32, f64, i8]>> -> !cir.ptr<!cir.double>
301
+ //
302
+ // to something like
303
+ //
304
+ // %1 = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
305
+ // %c8 = arith.constant 8 : index
306
+ // %view_1 = memref.view %1[%c8][] : memref<24xi8> to memref<f64>
307
+ // clang-format on
308
+ class CIRGetMemberOpLowering
309
+ : public mlir::OpConversionPattern<cir::GetMemberOp> {
310
+ cir::CIRDataLayout const &dataLayout;
311
+
312
+ public:
313
+ using mlir::OpConversionPattern<cir::GetMemberOp>::OpConversionPattern;
314
+
315
+ CIRGetMemberOpLowering (const mlir::TypeConverter &typeConverter,
316
+ mlir::MLIRContext *context,
317
+ cir::CIRDataLayout const &dataLayout)
318
+ : OpConversionPattern{typeConverter, context}, dataLayout{dataLayout} {}
319
+
320
+ mlir::LogicalResult
321
+ matchAndRewrite (cir::GetMemberOp op, OpAdaptor adaptor,
322
+ mlir::ConversionPatternRewriter &rewriter) const override {
323
+ auto pointeeType = op.getAddrTy ().getPointee ();
324
+ if (!mlir::isa<cir::StructType>(pointeeType))
325
+ op.emitError (" GetMemberOp only works on pointer to cir::StructType" );
326
+ auto structType = mlir::cast<cir::StructType>(pointeeType);
327
+ // For now, just rely on the datalayout of the high-level type since the
328
+ // datalayout of low-level type is not implemented yet. But since C++ is a
329
+ // concrete datalayout, both datalayouts are the same.
330
+ auto *structLayout = dataLayout.getStructLayout (structType);
331
+
332
+ // Get the lowered type: memref<!named_tuple.named_tuple<>>
333
+ auto memref = mlir::cast<mlir::MemRefType>(adaptor.getAddr ().getType ());
334
+ // Alias the memref of struct to a memref of an i8 array of the same size.
335
+ const std::array linearizedSize{
336
+ static_cast <std::int64_t >(dataLayout.getTypeStoreSize (structType))};
337
+ auto flattenMemRef = mlir::MemRefType::get (
338
+ linearizedSize, mlir::IntegerType::get (memref.getContext (), 8 ));
339
+ // Use a special cast because normal memref cast cannot do such an extreme
340
+ // cast.
341
+ auto bytesMemRef = rewriter.create <mlir::named_tuple::CastOp>(
342
+ op.getLoc (), mlir::TypeRange{flattenMemRef},
343
+ mlir::ValueRange{adaptor.getAddr ()});
344
+
345
+ auto memberIndex = op.getIndex ();
346
+ auto namedTupleType =
347
+ mlir::cast<mlir::named_tuple::NamedTupleType>(memref.getElementType ());
348
+ // The lowered type of the element to access in the named_tuple.
349
+ auto loweredMemberType = namedTupleType.getType (memberIndex);
350
+ auto elementMemRefTy = mlir::MemRefType::get ({}, loweredMemberType);
351
+ auto offset = structLayout->getElementOffset (memberIndex);
352
+ // Synthesize the byte access to right lowered type.
353
+ auto byteShift =
354
+ rewriter.create <mlir::arith::ConstantIndexOp>(op.getLoc (), offset);
355
+ rewriter.replaceOpWithNewOp <mlir::memref::ViewOp>(
356
+ op, elementMemRefTy, bytesMemRef, byteShift, mlir::ValueRange{});
357
+ return mlir::LogicalResult::success ();
358
+ }
359
+ };
360
+
288
361
class CIRCosOpLowering : public mlir ::OpConversionPattern<cir::CosOp> {
289
362
public:
290
363
using OpConversionPattern<cir::CosOp>::OpConversionPattern;
@@ -1359,7 +1432,8 @@ class CIRPtrStrideOpLowering
1359
1432
};
1360
1433
1361
1434
void populateCIRToMLIRConversionPatterns (mlir::RewritePatternSet &patterns,
1362
- mlir::TypeConverter &converter) {
1435
+ mlir::TypeConverter &converter,
1436
+ cir::CIRDataLayout &cirDataLayout) {
1363
1437
patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
1364
1438
1365
1439
patterns.add <
@@ -1378,6 +1452,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
1378
1452
CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
1379
1453
CIRVectorExtractLowering, CIRVectorCmpOpLowering>(converter,
1380
1454
patterns.getContext ());
1455
+ patterns.add <CIRGetMemberOpLowering>(converter, patterns.getContext (),
1456
+ cirDataLayout);
1381
1457
}
1382
1458
1383
1459
mlir::TypeConverter prepareTypeConverter (mlir::DataLayout &dataLayout) {
@@ -1434,7 +1510,7 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
1434
1510
curType = arrayType.getEltType ();
1435
1511
}
1436
1512
auto elementType = converter.convertType (curType);
1437
- // FIXME: The element type might not be converted (e.g. struct)
1513
+ // FIXME: The element type might not be converted
1438
1514
if (!elementType)
1439
1515
return nullptr ;
1440
1516
return mlir::MemRefType::get (shape, elementType);
@@ -1476,20 +1552,21 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
1476
1552
void ConvertCIRToMLIRPass::runOnOperation () {
1477
1553
auto module = getOperation ();
1478
1554
mlir::DataLayout dataLayout{module };
1555
+ cir::CIRDataLayout cirDataLayout{module };
1479
1556
auto converter = prepareTypeConverter (dataLayout);
1480
1557
1481
1558
mlir::RewritePatternSet patterns (&getContext ());
1482
1559
1483
1560
populateCIRLoopToSCFConversionPatterns (patterns, converter);
1484
- populateCIRToMLIRConversionPatterns (patterns, converter);
1561
+ populateCIRToMLIRConversionPatterns (patterns, converter, cirDataLayout );
1485
1562
1486
1563
mlir::ConversionTarget target (getContext ());
1487
1564
target.addLegalOp <mlir::ModuleOp>();
1488
- target
1489
- . addLegalDialect < mlir::affine::AffineDialect , mlir::arith::ArithDialect ,
1490
- mlir::memref::MemRefDialect , mlir::func::FuncDialect ,
1491
- mlir::scf::SCFDialect , mlir::cf::ControlFlowDialect ,
1492
- mlir::math::MathDialect, mlir::vector::VectorDialect >();
1565
+ target. addLegalDialect <mlir::affine::AffineDialect, mlir::arith::ArithDialect,
1566
+ mlir::memref::MemRefDialect , mlir::func::FuncDialect ,
1567
+ mlir::scf::SCFDialect , mlir::cf::ControlFlowDialect ,
1568
+ mlir::math::MathDialect , mlir::vector::VectorDialect ,
1569
+ mlir::named_tuple::NamedTupleDialect >();
1493
1570
target.addIllegalDialect <cir::CIRDialect>();
1494
1571
1495
1572
if (runAtStartHook)
0 commit comments