Skip to content

Commit 530c0ca

Browse files
committed
[CIR] Lower cir.get_member to named_tuple + memref casts
Emulate the member access through memory for now.
1 parent 5b778dd commit 530c0ca

File tree

5 files changed

+123
-15
lines changed

5 files changed

+123
-15
lines changed

clang/include/clang/CIR/LowerToMLIR.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ void populateCIRLoopToSCFConversionPatterns(mlir::RewritePatternSet &patterns,
2222

2323
mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout);
2424

25-
void runAtStartOfConvertCIRToMLIRPass(std::function<void(mlir::ConversionTarget)>);
25+
void runAtStartOfConvertCIRToMLIRPass(
26+
std::function<void(mlir::ConversionTarget)>);
2627

2728
mlir::ModuleOp
2829
lowerFromCIRToMLIRToLLVMDialect(mlir::ModuleOp theModule,

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 86 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,17 @@
2929
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
3030
#include "mlir/Dialect/Math/IR/Math.h"
3131
#include "mlir/Dialect/MemRef/IR/MemRef.h"
32+
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
3233
#include "mlir/Dialect/NamedTuple/IR/NamedTuple.h"
34+
#include "mlir/Dialect/NamedTuple/IR/NamedTupleDialect.h"
35+
#include "mlir/Dialect/NamedTuple/IR/NamedTupleTypes.h"
3336
#include "mlir/Dialect/SCF/IR/SCF.h"
3437
#include "mlir/Dialect/SCF/Transforms/Passes.h"
3538
#include "mlir/Dialect/Vector/IR/VectorOps.h"
39+
#include "mlir/IR/Attributes.h"
40+
#include "mlir/IR/BuiltinAttributes.h"
3641
#include "mlir/IR/BuiltinDialect.h"
42+
#include "mlir/IR/BuiltinOps.h"
3743
#include "mlir/IR/BuiltinTypes.h"
3844
#include "mlir/IR/Operation.h"
3945
#include "mlir/IR/Region.h"
@@ -48,12 +54,14 @@
4854
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
4955
#include "mlir/Target/LLVMIR/Export.h"
5056
#include "mlir/Transforms/DialectConversion.h"
57+
#include "clang/CIR/Dialect/IR/CIRDataLayout.h"
5158
#include "clang/CIR/Dialect/IR/CIRDialect.h"
5259
#include "clang/CIR/Dialect/IR/CIRTypes.h"
5360
#include "clang/CIR/LowerToLLVM.h"
5461
#include "clang/CIR/LowerToMLIR.h"
5562
#include "clang/CIR/LoweringHelpers.h"
5663
#include "clang/CIR/Passes.h"
64+
#include "llvm/ADT/ArrayRef.h"
5765
#include "llvm/ADT/STLExtras.h"
5866
#include "llvm/ADT/Sequence.h"
5967
#include "llvm/ADT/SmallVector.h"
@@ -176,7 +184,7 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
176184
mlir::Type mlirType =
177185
convertTypeForMemory(*getTypeConverter(), adaptor.getAllocaType());
178186

179-
// FIXME: Some types can not be converted yet (e.g. struct)
187+
// FIXME: Some types can not be converted yet
180188
if (!mlirType)
181189
return mlir::LogicalResult::failure();
182190

@@ -285,6 +293,71 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
285293
}
286294
};
287295

296+
// Lower cir.get_member
297+
//
298+
// clang-format off
299+
//
300+
// %5 = cir.get_member %1[1] {name = "b"} : !cir.ptr<!named_tuple.named_tuple<"s", [i32, f64, i8]>> -> !cir.ptr<!cir.double>
301+
//
302+
// to something like
303+
//
304+
// %1 = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
305+
// %c8 = arith.constant 8 : index
306+
// %view_1 = memref.view %1[%c8][] : memref<24xi8> to memref<f64>
307+
// clang-format on
308+
class CIRGetMemberOpLowering
309+
: public mlir::OpConversionPattern<cir::GetMemberOp> {
310+
cir::CIRDataLayout const &dataLayout;
311+
312+
public:
313+
using mlir::OpConversionPattern<cir::GetMemberOp>::OpConversionPattern;
314+
315+
CIRGetMemberOpLowering(const mlir::TypeConverter &typeConverter,
316+
mlir::MLIRContext *context,
317+
cir::CIRDataLayout const &dataLayout)
318+
: OpConversionPattern{typeConverter, context}, dataLayout{dataLayout} {}
319+
320+
mlir::LogicalResult
321+
matchAndRewrite(cir::GetMemberOp op, OpAdaptor adaptor,
322+
mlir::ConversionPatternRewriter &rewriter) const override {
323+
auto pointeeType = op.getAddrTy().getPointee();
324+
if (!mlir::isa<cir::StructType>(pointeeType))
325+
op.emitError("GetMemberOp only works on pointer to cir::StructType");
326+
auto structType = mlir::cast<cir::StructType>(pointeeType);
327+
// For now, just rely on the datalayout of the high-level type since the
328+
// datalayout of low-level type is not implemented yet. But since C++ is a
329+
// concrete datalayout, both datalayouts are the same.
330+
auto *structLayout = dataLayout.getStructLayout(structType);
331+
332+
// Get the lowered type: memref<!named_tuple.named_tuple<>>
333+
auto memref = mlir::cast<mlir::MemRefType>(adaptor.getAddr().getType());
334+
// Alias the memref of struct to a memref of an i8 array of the same size.
335+
const std::array linearizedSize{
336+
static_cast<std::int64_t>(dataLayout.getTypeStoreSize(structType))};
337+
auto flattenMemRef = mlir::MemRefType::get(
338+
linearizedSize, mlir::IntegerType::get(memref.getContext(), 8));
339+
// Use a special cast because normal memref cast cannot do such an extreme
340+
// cast.
341+
auto bytesMemRef = rewriter.create<mlir::named_tuple::CastOp>(
342+
op.getLoc(), mlir::TypeRange{flattenMemRef},
343+
mlir::ValueRange{adaptor.getAddr()});
344+
345+
auto memberIndex = op.getIndex();
346+
auto namedTupleType =
347+
mlir::cast<mlir::named_tuple::NamedTupleType>(memref.getElementType());
348+
// The lowered type of the element to access in the named_tuple.
349+
auto loweredMemberType = namedTupleType.getType(memberIndex);
350+
auto elementMemRefTy = mlir::MemRefType::get({}, loweredMemberType);
351+
auto offset = structLayout->getElementOffset(memberIndex);
352+
// Synthesize the byte access to right lowered type.
353+
auto byteShift =
354+
rewriter.create<mlir::arith::ConstantIndexOp>(op.getLoc(), offset);
355+
rewriter.replaceOpWithNewOp<mlir::memref::ViewOp>(
356+
op, elementMemRefTy, bytesMemRef, byteShift, mlir::ValueRange{});
357+
return mlir::LogicalResult::success();
358+
}
359+
};
360+
288361
class CIRCosOpLowering : public mlir::OpConversionPattern<cir::CosOp> {
289362
public:
290363
using OpConversionPattern<cir::CosOp>::OpConversionPattern;
@@ -1359,7 +1432,8 @@ class CIRPtrStrideOpLowering
13591432
};
13601433

13611434
void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
1362-
mlir::TypeConverter &converter) {
1435+
mlir::TypeConverter &converter,
1436+
cir::CIRDataLayout &cirDataLayout) {
13631437
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());
13641438

13651439
patterns.add<
@@ -1378,6 +1452,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
13781452
CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
13791453
CIRVectorExtractLowering, CIRVectorCmpOpLowering>(converter,
13801454
patterns.getContext());
1455+
patterns.add<CIRGetMemberOpLowering>(converter, patterns.getContext(),
1456+
cirDataLayout);
13811457
}
13821458

13831459
mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
@@ -1434,7 +1510,7 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
14341510
curType = arrayType.getEltType();
14351511
}
14361512
auto elementType = converter.convertType(curType);
1437-
// FIXME: The element type might not be converted (e.g. struct)
1513+
// FIXME: The element type might not be converted
14381514
if (!elementType)
14391515
return nullptr;
14401516
return mlir::MemRefType::get(shape, elementType);
@@ -1476,20 +1552,21 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
14761552
void ConvertCIRToMLIRPass::runOnOperation() {
14771553
auto module = getOperation();
14781554
mlir::DataLayout dataLayout{module};
1555+
cir::CIRDataLayout cirDataLayout{module};
14791556
auto converter = prepareTypeConverter(dataLayout);
14801557

14811558
mlir::RewritePatternSet patterns(&getContext());
14821559

14831560
populateCIRLoopToSCFConversionPatterns(patterns, converter);
1484-
populateCIRToMLIRConversionPatterns(patterns, converter);
1561+
populateCIRToMLIRConversionPatterns(patterns, converter, cirDataLayout);
14851562

14861563
mlir::ConversionTarget target(getContext());
14871564
target.addLegalOp<mlir::ModuleOp>();
1488-
target
1489-
.addLegalDialect<mlir::affine::AffineDialect, mlir::arith::ArithDialect,
1490-
mlir::memref::MemRefDialect, mlir::func::FuncDialect,
1491-
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
1492-
mlir::math::MathDialect, mlir::vector::VectorDialect>();
1565+
target.addLegalDialect<mlir::affine::AffineDialect, mlir::arith::ArithDialect,
1566+
mlir::memref::MemRefDialect, mlir::func::FuncDialect,
1567+
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
1568+
mlir::math::MathDialect, mlir::vector::VectorDialect,
1569+
mlir::named_tuple::NamedTupleDialect>();
14931570
target.addIllegalDialect<cir::CIRDialect>();
14941571

14951572
if (runAtStartHook)

clang/test/CIR/Lowering/ThroughMLIR/struct.cpp

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,38 @@
33

44
struct s {
55
int a;
6-
float b;
6+
double b;
7+
char c;
78
};
8-
int main() { s v; }
9-
// CHECK: memref<!named_tuple.named_tuple<"s", [i32, f32]>>
9+
10+
int main() {
11+
s v;
12+
// CHECK: %[[ALLOCA:.+]] = memref.alloca() {alignment = 8 : i64} : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>>
13+
v.a = 7;
14+
// CHECK: %[[C_7:.+]] = arith.constant 7 : i32
15+
// CHECK: %[[I8_EQUIV_A:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
16+
// CHECK: %[[OFFSET_A:.+]] = arith.constant 0 : index
17+
// CHECK: %[[VIEW_A:.+]] = memref.view %[[I8_EQUIV_A]][%[[OFFSET_A]]][] : memref<24xi8> to memref<i32>
18+
// CHECK: memref.store %[[C_7]], %[[VIEW_A]][] : memref<i32>
19+
20+
v.b = 3.;
21+
// CHECK: %[[C_3:.+]] = arith.constant 3.000000e+00 : f64
22+
// CHECK: %[[I8_EQUIV_B:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
23+
// CHECK: %[[OFFSET_B:.+]] = arith.constant 8 : index
24+
// CHECK: %[[VIEW_B:.+]] = memref.view %[[I8_EQUIV_B]][%[[OFFSET_B]]][] : memref<24xi8> to memref<f64>
25+
// CHECK: memref.store %[[C_3]], %[[VIEW_B]][] : memref<f64>
26+
27+
v.c = 'z';
28+
// CHECK: %[[C_122:.+]] = arith.constant 122 : i8
29+
// CHECK: %[[I8_EQUIV_C:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
30+
// CHECK: %[[OFFSET_C:.+]] = arith.constant 16 : index
31+
// CHECK: %[[VIEW_C:.+]] = memref.view %[[I8_EQUIV_C]][%[[OFFSET_C]]][] : memref<24xi8> to memref<i8>
32+
// memref.store %[[C_122]], %[[VIEW_C]][] : memref<i8>
33+
34+
return v.c;
35+
// CHECK: %[[I8_EQUIV_C_1:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
36+
// CHECK: %[[OFFSET_C_1:.+]] = arith.constant 16 : index
37+
// CHECK: %[[VIEW_C_1:.+]] = memref.view %[[I8_EQUIV_C_1]][%[[OFFSET_C_1]]][] : memref<24xi8> to memref<i8>
38+
// CHECK: %[[VALUE_C:.+]] = memref.load %[[VIEW_C_1]][] : memref<i8>
39+
// CHECK: %[[VALUE_RET:.+]] = arith.extsi %[[VALUE_C]] : i8 to i32
40+
}

mlir/include/mlir/Dialect/NamedTuple/IR/NamedTuple.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#ifndef MLIR_DIALECT_NAMED_TUPLE_IR_NAMED_TUPLE_H
1414
#define MLIR_DIALECT_NAMED_TUPLE_IR_NAMED_TUPLE_H
1515

16-
//#include "mlir/IR/Dialect.h"
1716
#include "mlir/Dialect/NamedTuple/IR/NamedTupleDialect.h"
1817
#include "mlir/Dialect/NamedTuple/IR/NamedTupleTypes.h"
1918

mlir/include/mlir/InitAllDialects.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@
6060
#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
6161
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
6262
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
63-
#include "mlir/Dialect/NamedTuple/IR/NamedTuple.h"
6463
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
64+
#include "mlir/Dialect/NamedTuple/IR/NamedTuple.h"
6565
#include "mlir/Dialect/OpenACC/OpenACC.h"
6666
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
6767
#include "mlir/Dialect/PDL/IR/PDL.h"

0 commit comments

Comments
 (0)