Skip to content

Commit 352959f

Browse files
committed
[CIR][WIP] Add ABI lowering pass
This patch adds a new pass cir-abi-lowering to the CIR dialect. This pass runs before the CallConvLowering pass, and it expands all ABI-dependent types and operations inside a function to their ABI-independent equivalences according to the ABI specification. This patch also moves the lowering code of the following types and operations from the LLVM lowering conversion to the new pass: - The pointer-to-data-member type `cir.data_member`; - The pointer-to-member-function type `cir.method`; - All operations working on operands of the above types.
1 parent a26ebb4 commit 352959f

File tree

7 files changed

+340
-260
lines changed

7 files changed

+340
-260
lines changed

clang/include/clang/CIR/Dialect/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ std::unique_ptr<Pass> createFlattenCFGPass();
4040
std::unique_ptr<Pass> createHoistAllocasPass();
4141
std::unique_ptr<Pass> createGotoSolverPass();
4242

43+
/// Create a pass to expand ABI-dependent types and operations.
44+
std::unique_ptr<Pass> createABILoweringPass();
45+
4346
/// Create a pass to lower ABI-independent function definitions/calls.
4447
std::unique_ptr<Pass> createCallConvLoweringPass();
4548

clang/include/clang/CIR/Dialect/Passes.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,27 @@ def LibOpt : Pass<"cir-lib-opt"> {
180180
];
181181
}
182182

183+
def ABILowering : Pass<"cir-abi-lowering"> {
184+
let summary = "Expands ABI-dependent types and operations";
185+
let description = [{
186+
This pass expands ABI-dependent CIR types and operations to more "primitive"
187+
ABI-independent CIR types and operations according to the target ABI
188+
specification.
189+
190+
Some CIR types, such as pointers to members, may have different layouts and
191+
representations under different target ABIs. This pass expands these types
192+
to their underlying representations as specified by the target ABI. For
193+
example, when targeting Itanium ABI, this pass will replace pointers to
194+
member functions with a struct with two ptrdiff_t fields.
195+
196+
Similarly, some CIR operations may also behave differently under different
197+
target ABIs. This pass also expands these operations to more "primitive"
198+
CIR operations as specified by the target ABI.
199+
}];
200+
let constructor = "mlir::createABILoweringPass()";
201+
let dependentDialects = ["cir::CIRDialect"];
202+
}
203+
183204
def CallConvLowering : Pass<"cir-call-conv-lowering"> {
184205
let summary = "Handle calling conventions for CIR functions";
185206
let description = [{

clang/lib/CIR/CodeGen/CIRPasses.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ mlir::LogicalResult runCIRToCIRPasses(
9797
namespace mlir {
9898

9999
void populateCIRPreLoweringPasses(OpPassManager &pm, bool useCCLowering) {
100+
pm.addPass(createABILoweringPass());
100101
if (useCCLowering)
101102
pm.addPass(createCallConvLoweringPass());
102103
pm.addPass(createHoistAllocasPass());
Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
//===- ABILowering.cpp - Expands ABI-dependent types and operations -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements the CIR ABI lowering pass which expands ABI-dependent
10+
// types and operations to equivalent ABI-independent types and operations.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "TargetLowering/LowerModule.h"
15+
#include "mlir/IR/BuiltinOps.h"
16+
#include "mlir/IR/PatternMatch.h"
17+
#include "mlir/Interfaces/DataLayoutInterfaces.h"
18+
#include "mlir/Pass/Pass.h"
19+
#include "mlir/Transforms/DialectConversion.h"
20+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
21+
#include "clang/CIR/Dialect/Passes.h"
22+
23+
#define GEN_PASS_DEF_ABILOWERING
24+
#include "clang/CIR/Dialect/Passes.h.inc"
25+
26+
namespace cir {
27+
namespace {
28+
29+
template <typename Op>
30+
class CIROpABILoweringPattern : public mlir::OpConversionPattern<Op> {
31+
protected:
32+
mlir::DataLayout *dataLayout;
33+
cir::LowerModule *lowerModule;
34+
35+
public:
36+
CIROpABILoweringPattern(mlir::MLIRContext *context,
37+
const mlir::TypeConverter &typeConverter,
38+
mlir::DataLayout &dataLayout,
39+
cir::LowerModule &lowerModule)
40+
: mlir::OpConversionPattern<Op>(typeConverter, context),
41+
dataLayout(&dataLayout), lowerModule(&lowerModule) {}
42+
};
43+
44+
#define CIR_ABI_LOWERING_PATTERN(name, operation) \
45+
struct name : CIROpABILoweringPattern<operation> { \
46+
using CIROpABILoweringPattern<operation>::CIROpABILoweringPattern; \
47+
\
48+
mlir::LogicalResult \
49+
matchAndRewrite(operation op, OpAdaptor adaptor, \
50+
mlir::ConversionPatternRewriter &rewriter) const override; \
51+
};
52+
CIR_ABI_LOWERING_PATTERN(CIRCastOpABILowering, cir::CastOp)
53+
CIR_ABI_LOWERING_PATTERN(CIRGlobalOpABILowering, cir::GlobalOp)
54+
CIR_ABI_LOWERING_PATTERN(CIRConstantOpABILowering, cir::ConstantOp)
55+
CIR_ABI_LOWERING_PATTERN(CIRBaseDataMemberOpABILowering, cir::BaseDataMemberOp)
56+
CIR_ABI_LOWERING_PATTERN(CIRBaseMethodOpABILowering, cir::BaseMethodOp)
57+
CIR_ABI_LOWERING_PATTERN(CIRCmpOpABILowering, cir::CmpOp)
58+
CIR_ABI_LOWERING_PATTERN(CIRDerivedDataMemberOpABILowering,
59+
cir::DerivedDataMemberOp)
60+
CIR_ABI_LOWERING_PATTERN(CIRDerivedMethodOpABILowering, cir::DerivedMethodOp)
61+
CIR_ABI_LOWERING_PATTERN(CIRGetMethodOpABILowering, cir::GetMethodOp)
62+
CIR_ABI_LOWERING_PATTERN(CIRGetRuntimeMemberOpABILowering,
63+
cir::GetRuntimeMemberOp)
64+
#undef CIR_ABI_LOWERING_PATTERN
65+
66+
mlir::LogicalResult CIRCastOpABILowering::matchAndRewrite(
67+
cir::CastOp op, OpAdaptor adaptor,
68+
mlir::ConversionPatternRewriter &rewriter) const {
69+
switch (op.getKind()) {
70+
case cir::CastKind::bitcast: {
71+
if (!mlir::isa<cir::DataMemberType, cir::MethodType>(op.getSrc().getType()))
72+
break;
73+
74+
mlir::Type destTy = getTypeConverter()->convertType(op.getType());
75+
mlir::Value loweredResult;
76+
if (mlir::isa<cir::DataMemberType>(op.getSrc().getType()))
77+
loweredResult = lowerModule->getCXXABI().lowerDataMemberBitcast(
78+
op, destTy, adaptor.getSrc(), rewriter);
79+
else
80+
loweredResult = lowerModule->getCXXABI().lowerMethodBitcast(
81+
op, destTy, adaptor.getSrc(), rewriter);
82+
rewriter.replaceOp(op, loweredResult);
83+
return mlir::success();
84+
}
85+
case cir::CastKind::member_ptr_to_bool: {
86+
mlir::Value loweredResult;
87+
if (mlir::isa<cir::MethodType>(op.getSrc().getType()))
88+
loweredResult = lowerModule->getCXXABI().lowerMethodToBoolCast(
89+
op, adaptor.getSrc(), rewriter);
90+
else
91+
loweredResult = lowerModule->getCXXABI().lowerDataMemberToBoolCast(
92+
op, adaptor.getSrc(), rewriter);
93+
rewriter.replaceOp(op, loweredResult);
94+
return mlir::success();
95+
}
96+
default:
97+
break;
98+
}
99+
100+
return mlir::failure();
101+
}
102+
103+
mlir::LogicalResult CIRGlobalOpABILowering::matchAndRewrite(
104+
cir::GlobalOp op, OpAdaptor adaptor,
105+
mlir::ConversionPatternRewriter &rewriter) const {
106+
std::optional<mlir::Attribute> init = op.getInitialValue();
107+
if (!init.has_value())
108+
return mlir::failure();
109+
110+
if (auto dataMemberAttr = mlir::dyn_cast<cir::DataMemberAttr>(*init)) {
111+
mlir::DataLayout layout(op->getParentOfType<mlir::ModuleOp>());
112+
mlir::TypedAttr abiValue = lowerModule->getCXXABI().lowerDataMemberConstant(
113+
dataMemberAttr, layout, *typeConverter);
114+
auto abiOp = mlir::cast<GlobalOp>(rewriter.clone(*op.getOperation()));
115+
abiOp.setInitialValueAttr(abiValue);
116+
abiOp.setSymType(abiValue.getType());
117+
rewriter.replaceOp(op, abiOp);
118+
return mlir::success();
119+
}
120+
121+
return mlir::success();
122+
}
123+
124+
mlir::LogicalResult CIRConstantOpABILowering::matchAndRewrite(
125+
cir::ConstantOp op, OpAdaptor adaptor,
126+
mlir::ConversionPatternRewriter &rewriter) const {
127+
if (mlir::isa<cir::DataMemberType>(op.getType())) {
128+
auto dataMember = mlir::cast<cir::DataMemberAttr>(op.getValue());
129+
mlir::DataLayout layout(op->getParentOfType<mlir::ModuleOp>());
130+
mlir::TypedAttr abiValue = lowerModule->getCXXABI().lowerDataMemberConstant(
131+
dataMember, layout, *typeConverter);
132+
rewriter.replaceOpWithNewOp<ConstantOp>(op, abiValue);
133+
return mlir::success();
134+
}
135+
if (mlir::isa<cir::MethodType>(op.getType())) {
136+
auto method = mlir::cast<cir::MethodAttr>(op.getValue());
137+
mlir::DataLayout layout(op->getParentOfType<mlir::ModuleOp>());
138+
mlir::TypedAttr abiValue = lowerModule->getCXXABI().lowerMethodConstant(
139+
method, layout, *typeConverter);
140+
rewriter.replaceOpWithNewOp<ConstantOp>(op, abiValue);
141+
return mlir::success();
142+
}
143+
144+
return mlir::failure();
145+
}
146+
147+
mlir::LogicalResult CIRBaseDataMemberOpABILowering::matchAndRewrite(
148+
cir::BaseDataMemberOp op, OpAdaptor adaptor,
149+
mlir::ConversionPatternRewriter &rewriter) const {
150+
mlir::Value loweredResult = lowerModule->getCXXABI().lowerBaseDataMember(
151+
op, adaptor.getSrc(), rewriter);
152+
rewriter.replaceOp(op, loweredResult);
153+
return mlir::success();
154+
}
155+
156+
mlir::LogicalResult CIRBaseMethodOpABILowering::matchAndRewrite(
157+
cir::BaseMethodOp op, OpAdaptor adaptor,
158+
mlir::ConversionPatternRewriter &rewriter) const {
159+
mlir::Value loweredResult =
160+
lowerModule->getCXXABI().lowerBaseMethod(op, adaptor.getSrc(), rewriter);
161+
rewriter.replaceOp(op, loweredResult);
162+
return mlir::success();
163+
}
164+
165+
mlir::LogicalResult CIRCmpOpABILowering::matchAndRewrite(
166+
cir::CmpOp op, OpAdaptor adaptor,
167+
mlir::ConversionPatternRewriter &rewriter) const {
168+
auto type = op.getLhs().getType();
169+
if (!mlir::isa<cir::DataMemberType, cir::MethodType>(type))
170+
return mlir::failure();
171+
172+
mlir::Value loweredResult;
173+
if (mlir::isa<cir::DataMemberType>(type))
174+
loweredResult = lowerModule->getCXXABI().lowerDataMemberCmp(
175+
op, adaptor.getLhs(), adaptor.getRhs(), rewriter);
176+
else
177+
loweredResult = lowerModule->getCXXABI().lowerMethodCmp(
178+
op, adaptor.getLhs(), adaptor.getRhs(), rewriter);
179+
180+
rewriter.replaceOp(op, loweredResult);
181+
return mlir::success();
182+
}
183+
184+
mlir::LogicalResult CIRDerivedDataMemberOpABILowering::matchAndRewrite(
185+
cir::DerivedDataMemberOp op, OpAdaptor adaptor,
186+
mlir::ConversionPatternRewriter &rewriter) const {
187+
mlir::Value loweredResult = lowerModule->getCXXABI().lowerDerivedDataMember(
188+
op, adaptor.getSrc(), rewriter);
189+
rewriter.replaceOp(op, loweredResult);
190+
return mlir::success();
191+
}
192+
193+
mlir::LogicalResult CIRDerivedMethodOpABILowering::matchAndRewrite(
194+
cir::DerivedMethodOp op, OpAdaptor adaptor,
195+
mlir::ConversionPatternRewriter &rewriter) const {
196+
mlir::Value loweredResult = lowerModule->getCXXABI().lowerDerivedMethod(
197+
op, adaptor.getSrc(), rewriter);
198+
rewriter.replaceOp(op, loweredResult);
199+
return mlir::success();
200+
}
201+
202+
mlir::LogicalResult CIRGetMethodOpABILowering::matchAndRewrite(
203+
cir::GetMethodOp op, OpAdaptor adaptor,
204+
mlir::ConversionPatternRewriter &rewriter) const {
205+
mlir::Value loweredResults[2];
206+
lowerModule->getCXXABI().lowerGetMethod(
207+
op, loweredResults, adaptor.getMethod(), adaptor.getObject(), rewriter);
208+
rewriter.replaceOp(op, loweredResults);
209+
return mlir::success();
210+
}
211+
212+
mlir::LogicalResult CIRGetRuntimeMemberOpABILowering::matchAndRewrite(
213+
cir::GetRuntimeMemberOp op, OpAdaptor adaptor,
214+
mlir::ConversionPatternRewriter &rewriter) const {
215+
mlir::Type resTy = getTypeConverter()->convertType(op.getType());
216+
mlir::Operation *llvmOp = lowerModule->getCXXABI().lowerGetRuntimeMember(
217+
op, resTy, adaptor.getAddr(), adaptor.getMember(), rewriter);
218+
rewriter.replaceOp(op, llvmOp);
219+
return mlir::success();
220+
}
221+
222+
static void prepareABITypeConverter(mlir::TypeConverter &converter,
223+
mlir::DataLayout &dataLayout,
224+
cir::LowerModule &lowerModule) {
225+
converter.addConversion([&](mlir::Type type) -> mlir::Type { return type; });
226+
converter.addConversion([&](cir::DataMemberType type) -> mlir::Type {
227+
mlir::Type abiType =
228+
lowerModule.getCXXABI().lowerDataMemberType(type, converter);
229+
return converter.convertType(abiType);
230+
});
231+
converter.addConversion([&](cir::MethodType type) -> mlir::Type {
232+
mlir::Type abiType =
233+
lowerModule.getCXXABI().lowerMethodType(type, converter);
234+
return converter.convertType(abiType);
235+
});
236+
}
237+
238+
static void populateABILoweringPatterns(mlir::RewritePatternSet &patterns,
239+
mlir::TypeConverter &converter,
240+
mlir::DataLayout &dataLayout,
241+
cir::LowerModule &lowerModule) {
242+
patterns.add<
243+
// clang-format off
244+
CIRBaseDataMemberOpABILowering,
245+
CIRBaseMethodOpABILowering,
246+
CIRCastOpABILowering,
247+
CIRCmpOpABILowering,
248+
CIRConstantOpABILowering,
249+
CIRDerivedDataMemberOpABILowering,
250+
CIRDerivedMethodOpABILowering,
251+
CIRGetMethodOpABILowering,
252+
CIRGetRuntimeMemberOpABILowering,
253+
CIRGlobalOpABILowering
254+
// clang-format on
255+
>(patterns.getContext(), converter, dataLayout, lowerModule);
256+
}
257+
258+
//===----------------------------------------------------------------------===//
259+
// The Pass
260+
//===----------------------------------------------------------------------===//
261+
262+
struct ABILoweringPass : ::impl::ABILoweringBase<ABILoweringPass> {
263+
using ABILoweringBase::ABILoweringBase;
264+
265+
void runOnOperation() override;
266+
llvm::StringRef getArgument() const override { return "cir-abi-lowering"; };
267+
};
268+
269+
void ABILoweringPass::runOnOperation() {
270+
auto module = mlir::cast<mlir::ModuleOp>(getOperation());
271+
mlir::MLIRContext *ctx = module.getContext();
272+
273+
mlir::PatternRewriter rewriter(ctx);
274+
std::unique_ptr<cir::LowerModule> lowerModule =
275+
cir::createLowerModule(module, rewriter);
276+
277+
mlir::DataLayout dataLayout(module);
278+
mlir::TypeConverter converter;
279+
prepareABITypeConverter(converter, dataLayout, *lowerModule);
280+
281+
mlir::RewritePatternSet patterns(ctx);
282+
populateABILoweringPatterns(patterns, converter, dataLayout, *lowerModule);
283+
284+
mlir::ConversionTarget target(*ctx);
285+
target.addLegalOp<mlir::ModuleOp>();
286+
target.addLegalDialect<cir::CIRDialect>();
287+
288+
// TODO: mark operations working on member pointers as illegal
289+
290+
// Illegal: base-to-derived and derived-to-base conversions on member pointers
291+
target.addIllegalOp<cir::BaseDataMemberOp, cir::BaseMethodOp,
292+
cir::DerivedDataMemberOp, cir::DerivedMethodOp>();
293+
// Illegal: indirection on member pointers
294+
target.addIllegalOp<cir::GetRuntimeMemberOp, cir::GetMethodOp>();
295+
296+
if (failed(mlir::applyPartialConversion(module, target, std::move(patterns))))
297+
signalPassFailure();
298+
}
299+
300+
} // namespace
301+
} // namespace cir
302+
303+
std::unique_ptr<mlir::Pass> mlir::createABILoweringPass() {
304+
return std::make_unique<cir::ABILoweringPass>();
305+
}

clang/lib/CIR/Dialect/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ add_clang_library(MLIRCIRTransforms
1212
FlattenCFG.cpp
1313
GotoSolver.cpp
1414
SCFPrepare.cpp
15+
ABILowering.cpp
1516
CallConvLowering.cpp
1617
HoistAllocas.cpp
1718

0 commit comments

Comments
 (0)