Skip to content

Commit cc2b218

Browse files
committed
[CIR] Lower cir.get_member to named_tuple + memref casts
Emulate the member access through memory for now.
1 parent 6ff6c48 commit cc2b218

File tree

5 files changed

+123
-15
lines changed

5 files changed

+123
-15
lines changed

clang/include/clang/CIR/LowerToMLIR.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ namespace cir {
2020
void populateCIRLoopToSCFConversionPatterns(mlir::RewritePatternSet &patterns,
2121
mlir::TypeConverter &converter);
2222
mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout);
23-
void runAtStartOfConvertCIRToMLIRPass(std::function<void(mlir::ConversionTarget)>);
23+
void runAtStartOfConvertCIRToMLIRPass(
24+
std::function<void(mlir::ConversionTarget)>);
2425
} // namespace cir
2526

2627
#endif // CLANG_CIR_LOWERTOMLIR_H_

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 86 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,17 @@
2929
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
3030
#include "mlir/Dialect/Math/IR/Math.h"
3131
#include "mlir/Dialect/MemRef/IR/MemRef.h"
32+
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
3233
#include "mlir/Dialect/NamedTuple/IR/NamedTuple.h"
34+
#include "mlir/Dialect/NamedTuple/IR/NamedTupleDialect.h"
35+
#include "mlir/Dialect/NamedTuple/IR/NamedTupleTypes.h"
3336
#include "mlir/Dialect/SCF/IR/SCF.h"
3437
#include "mlir/Dialect/SCF/Transforms/Passes.h"
3538
#include "mlir/Dialect/Vector/IR/VectorOps.h"
39+
#include "mlir/IR/Attributes.h"
40+
#include "mlir/IR/BuiltinAttributes.h"
3641
#include "mlir/IR/BuiltinDialect.h"
42+
#include "mlir/IR/BuiltinOps.h"
3743
#include "mlir/IR/BuiltinTypes.h"
3844
#include "mlir/IR/Operation.h"
3945
#include "mlir/IR/Region.h"
@@ -48,11 +54,13 @@
4854
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
4955
#include "mlir/Target/LLVMIR/Export.h"
5056
#include "mlir/Transforms/DialectConversion.h"
57+
#include "clang/CIR/Dialect/IR/CIRDataLayout.h"
5158
#include "clang/CIR/Dialect/IR/CIRDialect.h"
5259
#include "clang/CIR/Dialect/IR/CIRTypes.h"
5360
#include "clang/CIR/LowerToMLIR.h"
5461
#include "clang/CIR/LoweringHelpers.h"
5562
#include "clang/CIR/Passes.h"
63+
#include "llvm/ADT/ArrayRef.h"
5664
#include "llvm/ADT/STLExtras.h"
5765
#include "llvm/ADT/Sequence.h"
5866
#include "llvm/ADT/SmallVector.h"
@@ -175,7 +183,7 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
175183
mlir::Type mlirType =
176184
convertTypeForMemory(*getTypeConverter(), adaptor.getAllocaType());
177185

178-
// FIXME: Some types can not be converted yet (e.g. struct)
186+
// FIXME: Some types can not be converted yet
179187
if (!mlirType)
180188
return mlir::LogicalResult::failure();
181189

@@ -284,6 +292,71 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
284292
}
285293
};
286294

295+
// Lower cir.get_member
296+
//
297+
// clang-format off
298+
//
299+
// %5 = cir.get_member %1[1] {name = "b"} : !cir.ptr<!named_tuple.named_tuple<"s", [i32, f64, i8]>> -> !cir.ptr<!cir.double>
300+
//
301+
// to something like
302+
//
303+
// %1 = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
304+
// %c8 = arith.constant 8 : index
305+
// %view_1 = memref.view %1[%c8][] : memref<24xi8> to memref<f64>
306+
// clang-format on
307+
class CIRGetMemberOpLowering
308+
: public mlir::OpConversionPattern<cir::GetMemberOp> {
309+
cir::CIRDataLayout const &dataLayout;
310+
311+
public:
312+
using mlir::OpConversionPattern<cir::GetMemberOp>::OpConversionPattern;
313+
314+
CIRGetMemberOpLowering(const mlir::TypeConverter &typeConverter,
315+
mlir::MLIRContext *context,
316+
cir::CIRDataLayout const &dataLayout)
317+
: OpConversionPattern{typeConverter, context}, dataLayout{dataLayout} {}
318+
319+
mlir::LogicalResult
320+
matchAndRewrite(cir::GetMemberOp op, OpAdaptor adaptor,
321+
mlir::ConversionPatternRewriter &rewriter) const override {
322+
auto pointeeType = op.getAddrTy().getPointee();
323+
if (!mlir::isa<cir::StructType>(pointeeType))
324+
op.emitError("GetMemberOp only works on pointer to cir::StructType");
325+
auto structType = mlir::cast<cir::StructType>(pointeeType);
326+
// For now, just rely on the datalayout of the high-level type since the
327+
// datalayout of low-level type is not implemented yet. But since C++ is a
328+
// concrete datalayout, both datalayouts are the same.
329+
auto *structLayout = dataLayout.getStructLayout(structType);
330+
331+
// Get the lowered type: memref<!named_tuple.named_tuple<>>
332+
auto memref = mlir::cast<mlir::MemRefType>(adaptor.getAddr().getType());
333+
// Alias the memref of struct to a memref of an i8 array of the same size.
334+
const std::array linearizedSize{
335+
static_cast<std::int64_t>(dataLayout.getTypeStoreSize(structType))};
336+
auto flattenMemRef = mlir::MemRefType::get(
337+
linearizedSize, mlir::IntegerType::get(memref.getContext(), 8));
338+
// Use a special cast because normal memref cast cannot do such an extreme
339+
// cast.
340+
auto bytesMemRef = rewriter.create<mlir::named_tuple::CastOp>(
341+
op.getLoc(), mlir::TypeRange{flattenMemRef},
342+
mlir::ValueRange{adaptor.getAddr()});
343+
344+
auto memberIndex = op.getIndex();
345+
auto namedTupleType =
346+
mlir::cast<mlir::named_tuple::NamedTupleType>(memref.getElementType());
347+
// The lowered type of the element to access in the named_tuple.
348+
auto loweredMemberType = namedTupleType.getType(memberIndex);
349+
auto elementMemRefTy = mlir::MemRefType::get({}, loweredMemberType);
350+
auto offset = structLayout->getElementOffset(memberIndex);
351+
// Synthesize the byte access to right lowered type.
352+
auto byteShift =
353+
rewriter.create<mlir::arith::ConstantIndexOp>(op.getLoc(), offset);
354+
rewriter.replaceOpWithNewOp<mlir::memref::ViewOp>(
355+
op, elementMemRefTy, bytesMemRef, byteShift, mlir::ValueRange{});
356+
return mlir::LogicalResult::success();
357+
}
358+
};
359+
287360
class CIRCosOpLowering : public mlir::OpConversionPattern<cir::CosOp> {
288361
public:
289362
using OpConversionPattern<cir::CosOp>::OpConversionPattern;
@@ -1360,7 +1433,8 @@ class CIRPtrStrideOpLowering
13601433
};
13611434

13621435
void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
1363-
mlir::TypeConverter &converter) {
1436+
mlir::TypeConverter &converter,
1437+
cir::CIRDataLayout &cirDataLayout) {
13641438
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());
13651439

13661440
patterns.add<
@@ -1379,6 +1453,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
13791453
CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
13801454
CIRVectorExtractLowering, CIRVectorCmpOpLowering>(converter,
13811455
patterns.getContext());
1456+
patterns.add<CIRGetMemberOpLowering>(converter, patterns.getContext(),
1457+
cirDataLayout);
13821458
}
13831459

13841460
mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
@@ -1435,7 +1511,7 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
14351511
curType = arrayType.getEltType();
14361512
}
14371513
auto elementType = converter.convertType(curType);
1438-
// FIXME: The element type might not be converted (e.g. struct)
1514+
// FIXME: The element type might not be converted
14391515
if (!elementType)
14401516
return nullptr;
14411517
return mlir::MemRefType::get(shape, elementType);
@@ -1477,20 +1553,21 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
14771553
void ConvertCIRToMLIRPass::runOnOperation() {
14781554
auto module = getOperation();
14791555
mlir::DataLayout dataLayout{module};
1556+
cir::CIRDataLayout cirDataLayout{module};
14801557
auto converter = prepareTypeConverter(dataLayout);
14811558

14821559
mlir::RewritePatternSet patterns(&getContext());
14831560

14841561
populateCIRLoopToSCFConversionPatterns(patterns, converter);
1485-
populateCIRToMLIRConversionPatterns(patterns, converter);
1562+
populateCIRToMLIRConversionPatterns(patterns, converter, cirDataLayout);
14861563

14871564
mlir::ConversionTarget target(getContext());
14881565
target.addLegalOp<mlir::ModuleOp>();
1489-
target
1490-
.addLegalDialect<mlir::affine::AffineDialect, mlir::arith::ArithDialect,
1491-
mlir::memref::MemRefDialect, mlir::func::FuncDialect,
1492-
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
1493-
mlir::math::MathDialect, mlir::vector::VectorDialect>();
1566+
target.addLegalDialect<mlir::affine::AffineDialect, mlir::arith::ArithDialect,
1567+
mlir::memref::MemRefDialect, mlir::func::FuncDialect,
1568+
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
1569+
mlir::math::MathDialect, mlir::vector::VectorDialect,
1570+
mlir::named_tuple::NamedTupleDialect>();
14941571
target.addIllegalDialect<cir::CIRDialect>();
14951572

14961573
if (runAtStartHook)

clang/test/CIR/Lowering/ThroughMLIR/struct.cpp

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,38 @@
33

44
struct s {
55
int a;
6-
float b;
6+
double b;
7+
char c;
78
};
8-
int main() { s v; }
9-
// CHECK: memref<!named_tuple.named_tuple<"s", [i32, f32]>>
9+
10+
int main() {
11+
s v;
12+
// CHECK: %[[ALLOCA:.+]] = memref.alloca() {alignment = 8 : i64} : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>>
13+
v.a = 7;
14+
// CHECK: %[[C_7:.+]] = arith.constant 7 : i32
15+
// CHECK: %[[I8_EQUIV_A:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
16+
// CHECK: %[[OFFSET_A:.+]] = arith.constant 0 : index
17+
// CHECK: %[[VIEW_A:.+]] = memref.view %[[I8_EQUIV_A]][%[[OFFSET_A]]][] : memref<24xi8> to memref<i32>
18+
// CHECK: memref.store %[[C_7]], %[[VIEW_A]][] : memref<i32>
19+
20+
v.b = 3.;
21+
// CHECK: %[[C_3:.+]] = arith.constant 3.000000e+00 : f64
22+
// CHECK: %[[I8_EQUIV_B:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
23+
// CHECK: %[[OFFSET_B:.+]] = arith.constant 8 : index
24+
// CHECK: %[[VIEW_B:.+]] = memref.view %[[I8_EQUIV_B]][%[[OFFSET_B]]][] : memref<24xi8> to memref<f64>
25+
// CHECK: memref.store %[[C_3]], %[[VIEW_B]][] : memref<f64>
26+
27+
v.c = 'z';
28+
// CHECK: %[[C_122:.+]] = arith.constant 122 : i8
29+
// CHECK: %[[I8_EQUIV_C:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
30+
// CHECK: %[[OFFSET_C:.+]] = arith.constant 16 : index
31+
// CHECK: %[[VIEW_C:.+]] = memref.view %[[I8_EQUIV_C]][%[[OFFSET_C]]][] : memref<24xi8> to memref<i8>
32+
// memref.store %[[C_122]], %[[VIEW_C]][] : memref<i8>
33+
34+
return v.c;
35+
// CHECK: %[[I8_EQUIV_C_1:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
36+
// CHECK: %[[OFFSET_C_1:.+]] = arith.constant 16 : index
37+
// CHECK: %[[VIEW_C_1:.+]] = memref.view %[[I8_EQUIV_C_1]][%[[OFFSET_C_1]]][] : memref<24xi8> to memref<i8>
38+
// CHECK: %[[VALUE_C:.+]] = memref.load %[[VIEW_C_1]][] : memref<i8>
39+
// CHECK: %[[VALUE_RET:.+]] = arith.extsi %[[VALUE_C]] : i8 to i32
40+
}

mlir/include/mlir/Dialect/NamedTuple/IR/NamedTuple.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#ifndef MLIR_DIALECT_NAMED_TUPLE_IR_NAMED_TUPLE_H
1414
#define MLIR_DIALECT_NAMED_TUPLE_IR_NAMED_TUPLE_H
1515

16-
//#include "mlir/IR/Dialect.h"
1716
#include "mlir/Dialect/NamedTuple/IR/NamedTupleDialect.h"
1817
#include "mlir/Dialect/NamedTuple/IR/NamedTupleTypes.h"
1918

mlir/include/mlir/InitAllDialects.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@
5959
#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
6060
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
6161
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
62-
#include "mlir/Dialect/NamedTuple/IR/NamedTuple.h"
6362
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
63+
#include "mlir/Dialect/NamedTuple/IR/NamedTuple.h"
6464
#include "mlir/Dialect/OpenACC/OpenACC.h"
6565
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
6666
#include "mlir/Dialect/PDL/IR/PDL.h"

0 commit comments

Comments
 (0)