2929#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
3030#include " mlir/Dialect/Math/IR/Math.h"
3131#include " mlir/Dialect/MemRef/IR/MemRef.h"
32+ #include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
3233#include " mlir/Dialect/NamedTuple/IR/NamedTuple.h"
34+ #include " mlir/Dialect/NamedTuple/IR/NamedTupleDialect.h"
35+ #include " mlir/Dialect/NamedTuple/IR/NamedTupleTypes.h"
3336#include " mlir/Dialect/SCF/IR/SCF.h"
3437#include " mlir/Dialect/SCF/Transforms/Passes.h"
3538#include " mlir/Dialect/Vector/IR/VectorOps.h"
39+ #include " mlir/IR/Attributes.h"
40+ #include " mlir/IR/BuiltinAttributes.h"
3641#include " mlir/IR/BuiltinDialect.h"
42+ #include " mlir/IR/BuiltinOps.h"
3743#include " mlir/IR/BuiltinTypes.h"
3844#include " mlir/IR/Operation.h"
3945#include " mlir/IR/Region.h"
4854#include " mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
4955#include " mlir/Target/LLVMIR/Export.h"
5056#include " mlir/Transforms/DialectConversion.h"
57+ #include " clang/CIR/Dialect/IR/CIRDataLayout.h"
5158#include " clang/CIR/Dialect/IR/CIRDialect.h"
5259#include " clang/CIR/Dialect/IR/CIRTypes.h"
5360#include " clang/CIR/LowerToMLIR.h"
5461#include " clang/CIR/LoweringHelpers.h"
5562#include " clang/CIR/Passes.h"
63+ #include " llvm/ADT/ArrayRef.h"
5664#include " llvm/ADT/STLExtras.h"
5765#include " llvm/ADT/Sequence.h"
5866#include " llvm/ADT/SmallVector.h"
@@ -175,7 +183,7 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
175183 mlir::Type mlirType =
176184 convertTypeForMemory (*getTypeConverter (), adaptor.getAllocaType ());
177185
178- // FIXME: Some types can not be converted yet (e.g. struct)
186+ // FIXME: Some types can not be converted yet
179187 if (!mlirType)
180188 return mlir::LogicalResult::failure ();
181189
@@ -277,6 +285,71 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
277285 }
278286};
279287
288+ // Lower cir.get_member
289+ //
290+ // clang-format off
291+ //
292+ // %5 = cir.get_member %1[1] {name = "b"} : !cir.ptr<!named_tuple.named_tuple<"s", [i32, f64, i8]>> -> !cir.ptr<!cir.double>
293+ //
294+ // to something like
295+ //
296+ // %1 = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
297+ // %c8 = arith.constant 8 : index
298+ // %view_1 = memref.view %1[%c8][] : memref<24xi8> to memref<f64>
299+ // clang-format on
300+ class CIRGetMemberOpLowering
301+ : public mlir::OpConversionPattern<cir::GetMemberOp> {
302+ cir::CIRDataLayout const &dataLayout;
303+
304+ public:
305+ using mlir::OpConversionPattern<cir::GetMemberOp>::OpConversionPattern;
306+
307+ CIRGetMemberOpLowering (const mlir::TypeConverter &typeConverter,
308+ mlir::MLIRContext *context,
309+ cir::CIRDataLayout const &dataLayout)
310+ : OpConversionPattern{typeConverter, context}, dataLayout{dataLayout} {}
311+
312+ mlir::LogicalResult
313+ matchAndRewrite (cir::GetMemberOp op, OpAdaptor adaptor,
314+ mlir::ConversionPatternRewriter &rewriter) const override {
315+ auto pointeeType = op.getAddrTy ().getPointee ();
316+ if (!mlir::isa<cir::StructType>(pointeeType))
317+ op.emitError (" GetMemberOp only works on pointer to cir::StructType" );
318+ auto structType = mlir::cast<cir::StructType>(pointeeType);
319+ // For now, just rely on the datalayout of the high-level type since the
320+ // datalayout of low-level type is not implemented yet. But since C++ is a
321+ // concrete datalayout, both datalayouts are the same.
322+ auto *structLayout = dataLayout.getStructLayout (structType);
323+
324+ // Get the lowered type: memref<!named_tuple.named_tuple<>>
325+ auto memref = mlir::cast<mlir::MemRefType>(adaptor.getAddr ().getType ());
326+ // Alias the memref of struct to a memref of an i8 array of the same size.
327+ const std::array linearizedSize{
328+ static_cast <std::int64_t >(dataLayout.getTypeStoreSize (structType))};
329+ auto flattenMemRef = mlir::MemRefType::get (
330+ linearizedSize, mlir::IntegerType::get (memref.getContext (), 8 ));
331+ // Use a special cast because normal memref cast cannot do such an extreme
332+ // cast.
333+ auto bytesMemRef = rewriter.create <mlir::named_tuple::CastOp>(
334+ op.getLoc (), mlir::TypeRange{flattenMemRef},
335+ mlir::ValueRange{adaptor.getAddr ()});
336+
337+ auto memberIndex = op.getIndex ();
338+ auto namedTupleType =
339+ mlir::cast<mlir::named_tuple::NamedTupleType>(memref.getElementType ());
340+ // The lowered type of the element to access in the named_tuple.
341+ auto loweredMemberType = namedTupleType.getType (memberIndex);
342+ auto elementMemRefTy = mlir::MemRefType::get ({}, loweredMemberType);
343+ auto offset = structLayout->getElementOffset (memberIndex);
344+ // Synthesize the byte access to right lowered type.
345+ auto byteShift =
346+ rewriter.create <mlir::arith::ConstantIndexOp>(op.getLoc (), offset);
347+ rewriter.replaceOpWithNewOp <mlir::memref::ViewOp>(
348+ op, elementMemRefTy, bytesMemRef, byteShift, mlir::ValueRange{});
349+ return mlir::LogicalResult::success ();
350+ }
351+ };
352+
280353class CIRCosOpLowering : public mlir ::OpConversionPattern<cir::CosOp> {
281354public:
282355 using OpConversionPattern<cir::CosOp>::OpConversionPattern;
@@ -1353,7 +1426,8 @@ class CIRPtrStrideOpLowering
13531426};
13541427
13551428void populateCIRToMLIRConversionPatterns (mlir::RewritePatternSet &patterns,
1356- mlir::TypeConverter &converter) {
1429+ mlir::TypeConverter &converter,
1430+ cir::CIRDataLayout &cirDataLayout) {
13571431 patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
13581432
13591433 patterns.add <
@@ -1372,6 +1446,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
13721446 CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
13731447 CIRVectorExtractLowering, CIRVectorCmpOpLowering>(converter,
13741448 patterns.getContext ());
1449+ patterns.add <CIRGetMemberOpLowering>(converter, patterns.getContext (),
1450+ cirDataLayout);
13751451}
13761452
13771453mlir::TypeConverter prepareTypeConverter (mlir::DataLayout &dataLayout) {
@@ -1428,7 +1504,7 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
14281504 curType = arrayType.getEltType ();
14291505 }
14301506 auto elementType = converter.convertType (curType);
1431- // FIXME: The element type might not be converted (e.g. struct)
1507+ // FIXME: The element type might not be converted
14321508 if (!elementType)
14331509 return nullptr ;
14341510 return mlir::MemRefType::get (shape, elementType);
@@ -1470,20 +1546,21 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
14701546void ConvertCIRToMLIRPass::runOnOperation () {
14711547 auto module = getOperation ();
14721548 mlir::DataLayout dataLayout{module };
1549+ cir::CIRDataLayout cirDataLayout{module };
14731550 auto converter = prepareTypeConverter (dataLayout);
14741551
14751552 mlir::RewritePatternSet patterns (&getContext ());
14761553
14771554 populateCIRLoopToSCFConversionPatterns (patterns, converter);
1478- populateCIRToMLIRConversionPatterns (patterns, converter);
1555+ populateCIRToMLIRConversionPatterns (patterns, converter, cirDataLayout );
14791556
14801557 mlir::ConversionTarget target (getContext ());
14811558 target.addLegalOp <mlir::ModuleOp>();
1482- target
1483- . addLegalDialect < mlir::affine::AffineDialect , mlir::arith::ArithDialect ,
1484- mlir::memref::MemRefDialect , mlir::func::FuncDialect ,
1485- mlir::scf::SCFDialect , mlir::cf::ControlFlowDialect ,
1486- mlir::math::MathDialect, mlir::vector::VectorDialect >();
1559+ target. addLegalDialect <mlir::affine::AffineDialect, mlir::arith::ArithDialect,
1560+ mlir::memref::MemRefDialect , mlir::func::FuncDialect ,
1561+ mlir::scf::SCFDialect , mlir::cf::ControlFlowDialect ,
1562+ mlir::math::MathDialect , mlir::vector::VectorDialect ,
1563+ mlir::named_tuple::NamedTupleDialect >();
14871564 target.addIllegalDialect <cir::CIRDialect>();
14881565
14891566 if (runAtStartHook)
0 commit comments