Skip to content

Commit 12626d1

Browse files
committed
[CIR] Lower cir.get_member to named_tuple + memref casts
Emulate the member access through memory for now.
1 parent a09a2e8 commit 12626d1

File tree

5 files changed

+119
-10
lines changed

5 files changed

+119
-10
lines changed

clang/include/clang/CIR/LowerToMLIR.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ void populateCIRLoopToSCFConversionPatterns(mlir::RewritePatternSet &patterns,
2222

2323
mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout);
2424

25-
void runAtStartOfConvertCIRToMLIRPass(std::function<void(mlir::ConversionTarget)>);
25+
void runAtStartOfConvertCIRToMLIRPass(
26+
std::function<void(mlir::ConversionTarget)>);
2627

2728
mlir::ModuleOp
2829
lowerFromCIRToMLIRToLLVMDialect(mlir::ModuleOp theModule,

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 82 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,16 @@
2727
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
2828
#include "mlir/Dialect/Math/IR/Math.h"
2929
#include "mlir/Dialect/MemRef/IR/MemRef.h"
30+
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
3031
#include "mlir/Dialect/NamedTuple/IR/NamedTuple.h"
32+
#include "mlir/Dialect/NamedTuple/IR/NamedTupleDialect.h"
33+
#include "mlir/Dialect/NamedTuple/IR/NamedTupleTypes.h"
3134
#include "mlir/Dialect/SCF/IR/SCF.h"
3235
#include "mlir/Dialect/Vector/IR/VectorOps.h"
36+
#include "mlir/IR/Attributes.h"
37+
#include "mlir/IR/BuiltinAttributes.h"
3338
#include "mlir/IR/BuiltinDialect.h"
39+
#include "mlir/IR/BuiltinOps.h"
3440
#include "mlir/IR/BuiltinTypes.h"
3541
#include "mlir/IR/Operation.h"
3642
#include "mlir/IR/Region.h"
@@ -45,12 +51,14 @@
4551
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
4652
#include "mlir/Target/LLVMIR/Export.h"
4753
#include "mlir/Transforms/DialectConversion.h"
54+
#include "clang/CIR/Dialect/IR/CIRDataLayout.h"
4855
#include "clang/CIR/Dialect/IR/CIRDialect.h"
4956
#include "clang/CIR/Dialect/IR/CIRTypes.h"
5057
#include "clang/CIR/LowerToLLVM.h"
5158
#include "clang/CIR/LowerToMLIR.h"
5259
#include "clang/CIR/LoweringHelpers.h"
5360
#include "clang/CIR/Passes.h"
61+
#include "llvm/ADT/ArrayRef.h"
5462
#include "llvm/ADT/STLExtras.h"
5563
#include "llvm/ADT/SmallVector.h"
5664
#include "llvm/ADT/TypeSwitch.h"
@@ -172,7 +180,7 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
172180
mlir::Type mlirType =
173181
convertTypeForMemory(*getTypeConverter(), adaptor.getAllocaType());
174182

175-
// FIXME: Some types can not be converted yet (e.g. struct)
183+
// FIXME: Some types can not be converted yet
176184
if (!mlirType)
177185
return mlir::LogicalResult::failure();
178186

@@ -281,6 +289,71 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
281289
}
282290
};
283291

292+
// Lower cir.get_member
293+
//
294+
// clang-format off
295+
//
296+
// %5 = cir.get_member %1[1] {name = "b"} : !cir.ptr<!named_tuple.named_tuple<"s", [i32, f64, i8]>> -> !cir.ptr<!cir.double>
297+
//
298+
// to something like
299+
//
300+
// %1 = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
301+
// %c8 = arith.constant 8 : index
302+
// %view_1 = memref.view %1[%c8][] : memref<24xi8> to memref<f64>
303+
// clang-format on
304+
class CIRGetMemberOpLowering
305+
: public mlir::OpConversionPattern<cir::GetMemberOp> {
306+
cir::CIRDataLayout const &dataLayout;
307+
308+
public:
309+
using mlir::OpConversionPattern<cir::GetMemberOp>::OpConversionPattern;
310+
311+
CIRGetMemberOpLowering(const mlir::TypeConverter &typeConverter,
312+
mlir::MLIRContext *context,
313+
cir::CIRDataLayout const &dataLayout)
314+
: OpConversionPattern{typeConverter, context}, dataLayout{dataLayout} {}
315+
316+
mlir::LogicalResult
317+
matchAndRewrite(cir::GetMemberOp op, OpAdaptor adaptor,
318+
mlir::ConversionPatternRewriter &rewriter) const override {
319+
auto pointeeType = op.getAddrTy().getPointee();
320+
if (!mlir::isa<cir::StructType>(pointeeType))
321+
op.emitError("GetMemberOp only works on pointer to cir::StructType");
322+
auto structType = mlir::cast<cir::StructType>(pointeeType);
323+
// For now, just rely on the datalayout of the high-level type since the
324+
// datalayout of low-level type is not implemented yet. But since C++ is a
325+
// concrete datalayout, both datalayouts are the same.
326+
auto *structLayout = dataLayout.getStructLayout(structType);
327+
328+
// Get the lowered type: memref<!named_tuple.named_tuple<>>
329+
auto memref = mlir::cast<mlir::MemRefType>(adaptor.getAddr().getType());
330+
// Alias the memref of struct to a memref of an i8 array of the same size.
331+
const std::array linearizedSize{
332+
static_cast<std::int64_t>(dataLayout.getTypeStoreSize(structType))};
333+
auto flattenMemRef = mlir::MemRefType::get(
334+
linearizedSize, mlir::IntegerType::get(memref.getContext(), 8));
335+
// Use a special cast because normal memref cast cannot do such an extreme
336+
// cast.
337+
auto bytesMemRef = rewriter.create<mlir::named_tuple::CastOp>(
338+
op.getLoc(), mlir::TypeRange{flattenMemRef},
339+
mlir::ValueRange{adaptor.getAddr()});
340+
341+
auto memberIndex = op.getIndex();
342+
auto namedTupleType =
343+
mlir::cast<mlir::named_tuple::NamedTupleType>(memref.getElementType());
344+
// The lowered type of the element to access in the named_tuple.
345+
auto loweredMemberType = namedTupleType.getType(memberIndex);
346+
auto elementMemRefTy = mlir::MemRefType::get({}, loweredMemberType);
347+
auto offset = structLayout->getElementOffset(memberIndex);
348+
// Synthesize the byte access to right lowered type.
349+
auto byteShift =
350+
rewriter.create<mlir::arith::ConstantIndexOp>(op.getLoc(), offset);
351+
rewriter.replaceOpWithNewOp<mlir::memref::ViewOp>(
352+
op, elementMemRefTy, bytesMemRef, byteShift, mlir::ValueRange{});
353+
return mlir::LogicalResult::success();
354+
}
355+
};
356+
284357
/// Converts CIR unary math ops (e.g., cir::SinOp) to their MLIR equivalents
285358
/// (e.g., math::SinOp) using a generic template to avoid redundant boilerplate
286359
/// matchAndRewrite definitions.
@@ -1277,7 +1350,8 @@ class CIRTrapOpLowering : public mlir::OpConversionPattern<cir::TrapOp> {
12771350
};
12781351

12791352
void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
1280-
mlir::TypeConverter &converter) {
1353+
mlir::TypeConverter &converter,
1354+
cir::CIRDataLayout &cirDataLayout) {
12811355
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());
12821356

12831357
patterns
@@ -1298,6 +1372,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
12981372
CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering,
12991373
CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering,
13001374
CIRTrapOpLowering>(converter, patterns.getContext());
1375+
patterns.add<CIRGetMemberOpLowering>(converter, patterns.getContext(),
1376+
cirDataLayout);
13011377
}
13021378

13031379
mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
@@ -1354,7 +1430,7 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
13541430
curType = arrayType.getEltType();
13551431
}
13561432
auto elementType = converter.convertType(curType);
1357-
// FIXME: The element type might not be converted (e.g. struct)
1433+
// FIXME: The element type might not be converted
13581434
if (!elementType)
13591435
return nullptr;
13601436
return mlir::MemRefType::get(shape, elementType);
@@ -1396,19 +1472,21 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
13961472
void ConvertCIRToMLIRPass::runOnOperation() {
13971473
auto module = getOperation();
13981474
mlir::DataLayout dataLayout{module};
1475+
cir::CIRDataLayout cirDataLayout{module};
13991476
auto converter = prepareTypeConverter(dataLayout);
14001477

14011478
mlir::RewritePatternSet patterns(&getContext());
14021479

14031480
populateCIRLoopToSCFConversionPatterns(patterns, converter);
1404-
populateCIRToMLIRConversionPatterns(patterns, converter);
1481+
populateCIRToMLIRConversionPatterns(patterns, converter, cirDataLayout);
14051482

14061483
mlir::ConversionTarget target(getContext());
14071484
target.addLegalOp<mlir::ModuleOp>();
14081485
target.addLegalDialect<mlir::affine::AffineDialect, mlir::arith::ArithDialect,
14091486
mlir::memref::MemRefDialect, mlir::func::FuncDialect,
14101487
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
14111488
mlir::math::MathDialect, mlir::vector::VectorDialect,
1489+
mlir::named_tuple::NamedTupleDialect,
14121490
mlir::LLVM::LLVMDialect>();
14131491
target.addIllegalDialect<cir::CIRDialect>();
14141492

clang/test/CIR/Lowering/ThroughMLIR/struct.cpp

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,38 @@
33

44
struct s {
55
int a;
6-
float b;
6+
double b;
7+
char c;
78
};
8-
int main() { s v; }
9-
// CHECK: memref<!named_tuple.named_tuple<"s", [i32, f32]>>
9+
10+
int main() {
11+
s v;
12+
// CHECK: %[[ALLOCA:.+]] = memref.alloca() {alignment = 8 : i64} : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>>
13+
v.a = 7;
14+
// CHECK: %[[C_7:.+]] = arith.constant 7 : i32
15+
// CHECK: %[[I8_EQUIV_A:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
16+
// CHECK: %[[OFFSET_A:.+]] = arith.constant 0 : index
17+
// CHECK: %[[VIEW_A:.+]] = memref.view %[[I8_EQUIV_A]][%[[OFFSET_A]]][] : memref<24xi8> to memref<i32>
18+
// CHECK: memref.store %[[C_7]], %[[VIEW_A]][] : memref<i32>
19+
20+
v.b = 3.;
21+
// CHECK: %[[C_3:.+]] = arith.constant 3.000000e+00 : f64
22+
// CHECK: %[[I8_EQUIV_B:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
23+
// CHECK: %[[OFFSET_B:.+]] = arith.constant 8 : index
24+
// CHECK: %[[VIEW_B:.+]] = memref.view %[[I8_EQUIV_B]][%[[OFFSET_B]]][] : memref<24xi8> to memref<f64>
25+
// CHECK: memref.store %[[C_3]], %[[VIEW_B]][] : memref<f64>
26+
27+
v.c = 'z';
28+
// CHECK: %[[C_122:.+]] = arith.constant 122 : i8
29+
// CHECK: %[[I8_EQUIV_C:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
30+
// CHECK: %[[OFFSET_C:.+]] = arith.constant 16 : index
31+
// CHECK: %[[VIEW_C:.+]] = memref.view %[[I8_EQUIV_C]][%[[OFFSET_C]]][] : memref<24xi8> to memref<i8>
32+
// memref.store %[[C_122]], %[[VIEW_C]][] : memref<i8>
33+
34+
return v.c;
35+
// CHECK: %[[I8_EQUIV_C_1:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
36+
// CHECK: %[[OFFSET_C_1:.+]] = arith.constant 16 : index
37+
// CHECK: %[[VIEW_C_1:.+]] = memref.view %[[I8_EQUIV_C_1]][%[[OFFSET_C_1]]][] : memref<24xi8> to memref<i8>
38+
// CHECK: %[[VALUE_C:.+]] = memref.load %[[VIEW_C_1]][] : memref<i8>
39+
// CHECK: %[[VALUE_RET:.+]] = arith.extsi %[[VALUE_C]] : i8 to i32
40+
}

mlir/include/mlir/Dialect/NamedTuple/IR/NamedTuple.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#ifndef MLIR_DIALECT_NAMED_TUPLE_IR_NAMED_TUPLE_H
1414
#define MLIR_DIALECT_NAMED_TUPLE_IR_NAMED_TUPLE_H
1515

16-
//#include "mlir/IR/Dialect.h"
1716
#include "mlir/Dialect/NamedTuple/IR/NamedTupleDialect.h"
1817
#include "mlir/Dialect/NamedTuple/IR/NamedTupleTypes.h"
1918

mlir/include/mlir/InitAllDialects.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@
6060
#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
6161
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
6262
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
63-
#include "mlir/Dialect/NamedTuple/IR/NamedTuple.h"
6463
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
64+
#include "mlir/Dialect/NamedTuple/IR/NamedTuple.h"
6565
#include "mlir/Dialect/OpenACC/OpenACC.h"
6666
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
6767
#include "mlir/Dialect/PDL/IR/PDL.h"

0 commit comments

Comments
 (0)