29
29
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
30
30
#include " mlir/Dialect/Math/IR/Math.h"
31
31
#include " mlir/Dialect/MemRef/IR/MemRef.h"
32
+ #include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
32
33
#include " mlir/Dialect/NamedTuple/IR/NamedTuple.h"
34
+ #include " mlir/Dialect/NamedTuple/IR/NamedTupleDialect.h"
35
+ #include " mlir/Dialect/NamedTuple/IR/NamedTupleTypes.h"
33
36
#include " mlir/Dialect/SCF/IR/SCF.h"
34
37
#include " mlir/Dialect/SCF/Transforms/Passes.h"
35
38
#include " mlir/Dialect/Vector/IR/VectorOps.h"
39
+ #include " mlir/IR/Attributes.h"
40
+ #include " mlir/IR/BuiltinAttributes.h"
36
41
#include " mlir/IR/BuiltinDialect.h"
42
+ #include " mlir/IR/BuiltinOps.h"
37
43
#include " mlir/IR/BuiltinTypes.h"
38
44
#include " mlir/IR/Operation.h"
39
45
#include " mlir/IR/Region.h"
48
54
#include " mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
49
55
#include " mlir/Target/LLVMIR/Export.h"
50
56
#include " mlir/Transforms/DialectConversion.h"
57
+ #include " clang/CIR/Dialect/IR/CIRDataLayout.h"
51
58
#include " clang/CIR/Dialect/IR/CIRDialect.h"
52
59
#include " clang/CIR/Dialect/IR/CIRTypes.h"
53
60
#include " clang/CIR/LowerToMLIR.h"
54
61
#include " clang/CIR/LoweringHelpers.h"
55
62
#include " clang/CIR/Passes.h"
63
+ #include " llvm/ADT/ArrayRef.h"
56
64
#include " llvm/ADT/STLExtras.h"
57
65
#include " llvm/ADT/Sequence.h"
58
66
#include " llvm/ADT/SmallVector.h"
@@ -175,7 +183,7 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
175
183
mlir::Type mlirType =
176
184
convertTypeForMemory (*getTypeConverter (), adaptor.getAllocaType ());
177
185
178
- // FIXME: Some types can not be converted yet (e.g. struct)
186
+ // FIXME: Some types can not be converted yet
179
187
if (!mlirType)
180
188
return mlir::LogicalResult::failure ();
181
189
@@ -284,6 +292,71 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
284
292
}
285
293
};
286
294
295
+ // Lower cir.get_member
296
+ //
297
+ // clang-format off
298
+ //
299
+ // %5 = cir.get_member %1[1] {name = "b"} : !cir.ptr<!named_tuple.named_tuple<"s", [i32, f64, i8]>> -> !cir.ptr<!cir.double>
300
+ //
301
+ // to something like
302
+ //
303
+ // %1 = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
304
+ // %c8 = arith.constant 8 : index
305
+ // %view_1 = memref.view %1[%c8][] : memref<24xi8> to memref<f64>
306
+ // clang-format on
307
+ class CIRGetMemberOpLowering
308
+ : public mlir::OpConversionPattern<cir::GetMemberOp> {
309
+ cir::CIRDataLayout const &dataLayout;
310
+
311
+ public:
312
+ using mlir::OpConversionPattern<cir::GetMemberOp>::OpConversionPattern;
313
+
314
+ CIRGetMemberOpLowering (const mlir::TypeConverter &typeConverter,
315
+ mlir::MLIRContext *context,
316
+ cir::CIRDataLayout const &dataLayout)
317
+ : OpConversionPattern{typeConverter, context}, dataLayout{dataLayout} {}
318
+
319
+ mlir::LogicalResult
320
+ matchAndRewrite (cir::GetMemberOp op, OpAdaptor adaptor,
321
+ mlir::ConversionPatternRewriter &rewriter) const override {
322
+ auto pointeeType = op.getAddrTy ().getPointee ();
323
+ if (!mlir::isa<cir::StructType>(pointeeType))
324
+ op.emitError (" GetMemberOp only works on pointer to cir::StructType" );
325
+ auto structType = mlir::cast<cir::StructType>(pointeeType);
326
+ // For now, just rely on the datalayout of the high-level type since the
327
+ // datalayout of low-level type is not implemented yet. But since C++ is a
328
+ // concrete datalayout, both datalayouts are the same.
329
+ auto *structLayout = dataLayout.getStructLayout (structType);
330
+
331
+ // Get the lowered type: memref<!named_tuple.named_tuple<>>
332
+ auto memref = mlir::cast<mlir::MemRefType>(adaptor.getAddr ().getType ());
333
+ // Alias the memref of struct to a memref of an i8 array of the same size.
334
+ const std::array linearizedSize{
335
+ static_cast <std::int64_t >(dataLayout.getTypeStoreSize (structType))};
336
+ auto flattenMemRef = mlir::MemRefType::get (
337
+ linearizedSize, mlir::IntegerType::get (memref.getContext (), 8 ));
338
+ // Use a special cast because normal memref cast cannot do such an extreme
339
+ // cast.
340
+ auto bytesMemRef = rewriter.create <mlir::named_tuple::CastOp>(
341
+ op.getLoc (), mlir::TypeRange{flattenMemRef},
342
+ mlir::ValueRange{adaptor.getAddr ()});
343
+
344
+ auto memberIndex = op.getIndex ();
345
+ auto namedTupleType =
346
+ mlir::cast<mlir::named_tuple::NamedTupleType>(memref.getElementType ());
347
+ // The lowered type of the element to access in the named_tuple.
348
+ auto loweredMemberType = namedTupleType.getType (memberIndex);
349
+ auto elementMemRefTy = mlir::MemRefType::get ({}, loweredMemberType);
350
+ auto offset = structLayout->getElementOffset (memberIndex);
351
+ // Synthesize the byte access to right lowered type.
352
+ auto byteShift =
353
+ rewriter.create <mlir::arith::ConstantIndexOp>(op.getLoc (), offset);
354
+ rewriter.replaceOpWithNewOp <mlir::memref::ViewOp>(
355
+ op, elementMemRefTy, bytesMemRef, byteShift, mlir::ValueRange{});
356
+ return mlir::LogicalResult::success ();
357
+ }
358
+ };
359
+
287
360
class CIRCosOpLowering : public mlir ::OpConversionPattern<cir::CosOp> {
288
361
public:
289
362
using OpConversionPattern<cir::CosOp>::OpConversionPattern;
@@ -1360,7 +1433,8 @@ class CIRPtrStrideOpLowering
1360
1433
};
1361
1434
1362
1435
void populateCIRToMLIRConversionPatterns (mlir::RewritePatternSet &patterns,
1363
- mlir::TypeConverter &converter) {
1436
+ mlir::TypeConverter &converter,
1437
+ cir::CIRDataLayout &cirDataLayout) {
1364
1438
patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
1365
1439
1366
1440
patterns.add <
@@ -1379,6 +1453,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
1379
1453
CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
1380
1454
CIRVectorExtractLowering, CIRVectorCmpOpLowering>(converter,
1381
1455
patterns.getContext ());
1456
+ patterns.add <CIRGetMemberOpLowering>(converter, patterns.getContext (),
1457
+ cirDataLayout);
1382
1458
}
1383
1459
1384
1460
mlir::TypeConverter prepareTypeConverter (mlir::DataLayout &dataLayout) {
@@ -1435,7 +1511,7 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
1435
1511
curType = arrayType.getEltType ();
1436
1512
}
1437
1513
auto elementType = converter.convertType (curType);
1438
- // FIXME: The element type might not be converted (e.g. struct)
1514
+ // FIXME: The element type might not be converted
1439
1515
if (!elementType)
1440
1516
return nullptr ;
1441
1517
return mlir::MemRefType::get (shape, elementType);
@@ -1477,20 +1553,21 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
1477
1553
void ConvertCIRToMLIRPass::runOnOperation () {
1478
1554
auto module = getOperation ();
1479
1555
mlir::DataLayout dataLayout{module };
1556
+ cir::CIRDataLayout cirDataLayout{module };
1480
1557
auto converter = prepareTypeConverter (dataLayout);
1481
1558
1482
1559
mlir::RewritePatternSet patterns (&getContext ());
1483
1560
1484
1561
populateCIRLoopToSCFConversionPatterns (patterns, converter);
1485
- populateCIRToMLIRConversionPatterns (patterns, converter);
1562
+ populateCIRToMLIRConversionPatterns (patterns, converter, cirDataLayout );
1486
1563
1487
1564
mlir::ConversionTarget target (getContext ());
1488
1565
target.addLegalOp <mlir::ModuleOp>();
1489
- target
1490
- . addLegalDialect < mlir::affine::AffineDialect , mlir::arith::ArithDialect ,
1491
- mlir::memref::MemRefDialect , mlir::func::FuncDialect ,
1492
- mlir::scf::SCFDialect , mlir::cf::ControlFlowDialect ,
1493
- mlir::math::MathDialect, mlir::vector::VectorDialect >();
1566
+ target. addLegalDialect <mlir::affine::AffineDialect, mlir::arith::ArithDialect,
1567
+ mlir::memref::MemRefDialect , mlir::func::FuncDialect ,
1568
+ mlir::scf::SCFDialect , mlir::cf::ControlFlowDialect ,
1569
+ mlir::math::MathDialect , mlir::vector::VectorDialect ,
1570
+ mlir::named_tuple::NamedTupleDialect >();
1494
1571
target.addIllegalDialect <cir::CIRDialect>();
1495
1572
1496
1573
if (runAtStartHook)
0 commit comments