Skip to content

Commit 02218df

Browse files
committed
[CIR] Initial implementation of lowering CIR to MLIR
Add support for lowering CIR to MLIR and emitting an MLIR text file. Lowering of global pointers is not yet supported.
1 parent ca1833b commit 02218df

File tree

14 files changed

+370
-0
lines changed

14 files changed

+370
-0
lines changed

clang/include/clang/CIR/CIRGenerator.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class CIRGenerator : public clang::ASTConsumer {
5555
void Initialize(clang::ASTContext &astContext) override;
5656
bool HandleTopLevelDecl(clang::DeclGroupRef group) override;
5757
mlir::ModuleOp getModule() const;
58+
mlir::MLIRContext &getContext() { return *mlirContext; }
5859
};
5960

6061
} // namespace cir

clang/include/clang/CIR/FrontendAction/CIRGenAction.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class CIRGenAction : public clang::ASTFrontendAction {
2929
EmitCIR,
3030
EmitLLVM,
3131
EmitBC,
32+
EmitMLIR,
3233
EmitObj,
3334
};
3435

@@ -59,6 +60,13 @@ class EmitCIRAction : public CIRGenAction {
5960
EmitCIRAction(mlir::MLIRContext *MLIRCtx = nullptr);
6061
};
6162

63+
class EmitMLIRAction : public CIRGenAction {
64+
virtual void anchor();
65+
66+
public:
67+
EmitMLIRAction(mlir::MLIRContext *MLIRCtx = nullptr);
68+
};
69+
6270
class EmitLLVMAction : public CIRGenAction {
6371
virtual void anchor();
6472

clang/include/clang/CIR/LowerToLLVM.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class Module;
2020
} // namespace llvm
2121

2222
namespace mlir {
23+
class MLIRContext;
2324
class ModuleOp;
2425
} // namespace mlir
2526

@@ -30,6 +31,9 @@ std::unique_ptr<llvm::Module>
3031
lowerDirectlyFromCIRToLLVMIR(mlir::ModuleOp mlirModule,
3132
llvm::LLVMContext &llvmCtx);
3233
} // namespace direct
34+
35+
mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp mlirModule,
36+
mlir::MLIRContext &mlirCtx);
3337
} // namespace cir
3438

3539
#endif // CLANG_CIR_LOWERTOLLVM_H

clang/include/clang/Driver/Options.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2987,6 +2987,8 @@ defm clangir : BoolFOption<"clangir",
29872987
BothFlags<[], [ClangOption, CC1Option], "">>;
29882988
def emit_cir : Flag<["-"], "emit-cir">, Visibility<[ClangOption, CC1Option]>,
29892989
Group<Action_Group>, HelpText<"Build ASTs and then lower to ClangIR">;
2990+
def emit_cir_mlir : Flag<["-"], "emit-cir-mlir">, Visibility<[CC1Option]>, Group<Action_Group>,
2991+
HelpText<"Build ASTs and then lower through ClangIR to MLIR, emit the .milr file">;
29902992
/// ClangIR-specific options - END
29912993

29922994
def flto_EQ : Joined<["-"], "flto=">,

clang/include/clang/Frontend/FrontendOptions.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ enum ActionKind {
6868
/// Emit a .cir file
6969
EmitCIR,
7070

71+
/// Emit a .mlir file
72+
EmitMLIR,
73+
7174
/// Emit a .ll file.
7275
EmitLLVM,
7376

clang/lib/CIR/FrontendAction/CIRGenAction.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ static BackendAction
2424
getBackendActionFromOutputType(CIRGenAction::OutputType Action) {
2525
switch (Action) {
2626
case CIRGenAction::OutputType::EmitCIR:
27+
case CIRGenAction::OutputType::EmitMLIR:
2728
assert(false &&
2829
"Unsupported output type for getBackendActionFromOutputType!");
2930
break; // Unreachable, but fall through to report that
@@ -82,6 +83,7 @@ class CIRGenConsumer : public clang::ASTConsumer {
8283
void HandleTranslationUnit(ASTContext &C) override {
8384
Gen->HandleTranslationUnit(C);
8485
mlir::ModuleOp MlirModule = Gen->getModule();
86+
mlir::MLIRContext &MlirCtx = Gen->getContext();
8587
switch (Action) {
8688
case CIRGenAction::OutputType::EmitCIR:
8789
if (OutputStream && MlirModule) {
@@ -90,6 +92,15 @@ class CIRGenConsumer : public clang::ASTConsumer {
9092
MlirModule->print(*OutputStream, Flags);
9193
}
9294
break;
95+
case CIRGenAction::OutputType::EmitMLIR: {
96+
auto LoweredMlirModule = lowerFromCIRToMLIR(MlirModule, MlirCtx);
97+
assert(OutputStream && "No output stream when lowering to MLIR!");
98+
// FIXME: we cannot roundtrip prettyForm=true right now.
99+
mlir::OpPrintingFlags Flags;
100+
Flags.enableDebugInfo(/*enable=*/true, /*prettyForm=*/false);
101+
LoweredMlirModule->print(*OutputStream, Flags);
102+
break;
103+
}
93104
case CIRGenAction::OutputType::EmitLLVM:
94105
case CIRGenAction::OutputType::EmitBC:
95106
case CIRGenAction::OutputType::EmitObj:
@@ -124,6 +135,8 @@ getOutputStream(CompilerInstance &CI, StringRef InFile,
124135
return CI.createDefaultOutputFile(false, InFile, "s");
125136
case CIRGenAction::OutputType::EmitCIR:
126137
return CI.createDefaultOutputFile(false, InFile, "cir");
138+
case CIRGenAction::OutputType::EmitMLIR:
139+
return CI.createDefaultOutputFile(false, InFile, "mlir");
127140
case CIRGenAction::OutputType::EmitLLVM:
128141
return CI.createDefaultOutputFile(false, InFile, "ll");
129142
case CIRGenAction::OutputType::EmitBC:
@@ -155,6 +168,10 @@ void EmitCIRAction::anchor() {}
155168
EmitCIRAction::EmitCIRAction(mlir::MLIRContext *MLIRCtx)
156169
: CIRGenAction(OutputType::EmitCIR, MLIRCtx) {}
157170

171+
void EmitMLIRAction::anchor() {}
172+
EmitMLIRAction::EmitMLIRAction(mlir::MLIRContext *MLIRCtx)
173+
: CIRGenAction(OutputType::EmitMLIR, MLIRCtx) {}
174+
158175
void EmitLLVMAction::anchor() {}
159176
EmitLLVMAction::EmitLLVMAction(mlir::MLIRContext *MLIRCtx)
160177
: CIRGenAction(OutputType::EmitLLVM, MLIRCtx) {}

clang/lib/CIR/FrontendAction/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ add_clang_library(clangCIRFrontendAction
1717
clangFrontend
1818
clangCIR
1919
clangCIRLoweringDirectToLLVM
20+
clangCIRLoweringThroughMLIR
2021
clangCodeGen
2122
MLIRCIR
2223
MLIRIR

clang/lib/CIR/Lowering/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
add_subdirectory(DirectToLLVM)
2+
add_subdirectory(ThroughMLIR)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
set(LLVM_LINK_COMPONENTS
2+
Core
3+
Support
4+
)
5+
6+
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
7+
8+
add_clang_library(clangCIRLoweringThroughMLIR
9+
LowerCIRToMLIR.cpp
10+
11+
DEPENDS
12+
LINK_LIBS
13+
MLIRIR
14+
${dialect_libs}
15+
MLIRCIR
16+
)
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
//====- LowerCIRToMLIR.cpp - Lowering from CIR to MLIR --------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements lowering of CIR operations to MLIR.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
14+
#include "mlir/IR/BuiltinDialect.h"
15+
#include "mlir/Pass/Pass.h"
16+
#include "mlir/Pass/PassManager.h"
17+
#include "mlir/Transforms/DialectConversion.h"
18+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
19+
#include "clang/CIR/Dialect/IR/CIRTypes.h"
20+
#include "clang/CIR/LowerToLLVM.h"
21+
#include "clang/CIR/MissingFeatures.h"
22+
#include "llvm/ADT/TypeSwitch.h"
23+
#include "llvm/Support/TimeProfiler.h"
24+
25+
using namespace cir;
26+
using namespace llvm;
27+
28+
namespace cir {
29+
30+
struct ConvertCIRToMLIRPass
31+
: public mlir::PassWrapper<ConvertCIRToMLIRPass,
32+
mlir::OperationPass<mlir::ModuleOp>> {
33+
void getDependentDialects(mlir::DialectRegistry &registry) const override {
34+
registry.insert<mlir::BuiltinDialect, mlir::memref::MemRefDialect>();
35+
}
36+
void runOnOperation() final;
37+
38+
StringRef getDescription() const override {
39+
return "Convert the CIR dialect module to MLIR standard dialects";
40+
}
41+
42+
StringRef getArgument() const override { return "cir-to-mlir"; }
43+
};
44+
45+
class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
46+
public:
47+
using OpConversionPattern<cir::GlobalOp>::OpConversionPattern;
48+
mlir::LogicalResult
49+
matchAndRewrite(cir::GlobalOp op, OpAdaptor adaptor,
50+
mlir::ConversionPatternRewriter &rewriter) const override {
51+
auto moduleOp = op->getParentOfType<mlir::ModuleOp>();
52+
if (!moduleOp)
53+
return mlir::failure();
54+
55+
mlir::OpBuilder b(moduleOp.getContext());
56+
57+
const auto cirSymType = op.getSymType();
58+
assert(!cir::MissingFeatures::convertTypeForMemory());
59+
auto convertedType = getTypeConverter()->convertType(cirSymType);
60+
if (!convertedType)
61+
return mlir::failure();
62+
auto memrefType = dyn_cast<mlir::MemRefType>(convertedType);
63+
if (!memrefType)
64+
memrefType = mlir::MemRefType::get({}, convertedType);
65+
// Add an optional alignment to the global memref.
66+
assert(!cir::MissingFeatures::opGlobalAlignment());
67+
mlir::IntegerAttr memrefAlignment = mlir::IntegerAttr();
68+
// Add an optional initial value to the global memref.
69+
mlir::Attribute initialValue = mlir::Attribute();
70+
std::optional<mlir::Attribute> init = op.getInitialValue();
71+
if (init.has_value()) {
72+
initialValue =
73+
llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(init.value())
74+
.Case<cir::IntAttr>([&](cir::IntAttr attr) {
75+
auto rtt = mlir::RankedTensorType::get({}, convertedType);
76+
return mlir::DenseIntElementsAttr::get(rtt, attr.getValue());
77+
})
78+
.Case<cir::FPAttr>([&](cir::FPAttr attr) {
79+
auto rtt = mlir::RankedTensorType::get({}, convertedType);
80+
return mlir::DenseFPElementsAttr::get(rtt, attr.getValue());
81+
})
82+
.Default([&](mlir::Attribute attr) {
83+
llvm_unreachable("GlobalOp lowering with initial value is not "
84+
"fully supported yet");
85+
return mlir::Attribute();
86+
});
87+
}
88+
89+
// Add symbol visibility
90+
assert(!cir::MissingFeatures::opGlobalLinkage());
91+
std::string symVisibility = "public";
92+
93+
assert(!cir::MissingFeatures::opGlobalConstant());
94+
bool isConstant = false;
95+
96+
rewriter.replaceOpWithNewOp<mlir::memref::GlobalOp>(
97+
op, b.getStringAttr(op.getSymName()),
98+
/*sym_visibility=*/b.getStringAttr(symVisibility),
99+
/*type=*/memrefType, initialValue,
100+
/*constant=*/isConstant,
101+
/*alignment=*/memrefAlignment);
102+
103+
return mlir::success();
104+
}
105+
};
106+
107+
void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
108+
mlir::TypeConverter &converter) {
109+
patterns.add<CIRGlobalOpLowering>(converter, patterns.getContext());
110+
}
111+
112+
static mlir::TypeConverter prepareTypeConverter() {
113+
mlir::TypeConverter converter;
114+
converter.addConversion([&](cir::PointerType type) -> mlir::Type {
115+
assert(!cir::MissingFeatures::convertTypeForMemory());
116+
mlir::Type ty = converter.convertType(type.getPointee());
117+
// FIXME: The pointee type might not be converted (e.g. struct)
118+
if (!ty)
119+
return nullptr;
120+
return mlir::MemRefType::get({}, ty);
121+
});
122+
converter.addConversion(
123+
[&](mlir::IntegerType type) -> mlir::Type { return type; });
124+
converter.addConversion(
125+
[&](mlir::FloatType type) -> mlir::Type { return type; });
126+
converter.addConversion([&](cir::VoidType type) -> mlir::Type { return {}; });
127+
converter.addConversion([&](cir::IntType type) -> mlir::Type {
128+
// arith dialect ops doesn't take signed integer -- drop cir sign here
129+
return mlir::IntegerType::get(
130+
type.getContext(), type.getWidth(),
131+
mlir::IntegerType::SignednessSemantics::Signless);
132+
});
133+
converter.addConversion([&](cir::SingleType type) -> mlir::Type {
134+
return mlir::Float32Type::get(type.getContext());
135+
});
136+
converter.addConversion([&](cir::DoubleType type) -> mlir::Type {
137+
return mlir::Float64Type::get(type.getContext());
138+
});
139+
converter.addConversion([&](cir::FP80Type type) -> mlir::Type {
140+
return mlir::Float80Type::get(type.getContext());
141+
});
142+
converter.addConversion([&](cir::LongDoubleType type) -> mlir::Type {
143+
return converter.convertType(type.getUnderlying());
144+
});
145+
converter.addConversion([&](cir::FP128Type type) -> mlir::Type {
146+
return mlir::Float128Type::get(type.getContext());
147+
});
148+
converter.addConversion([&](cir::FP16Type type) -> mlir::Type {
149+
return mlir::Float16Type::get(type.getContext());
150+
});
151+
converter.addConversion([&](cir::BF16Type type) -> mlir::Type {
152+
return mlir::BFloat16Type::get(type.getContext());
153+
});
154+
155+
return converter;
156+
}
157+
158+
void ConvertCIRToMLIRPass::runOnOperation() {
159+
auto module = getOperation();
160+
161+
auto converter = prepareTypeConverter();
162+
163+
mlir::RewritePatternSet patterns(&getContext());
164+
165+
populateCIRToMLIRConversionPatterns(patterns, converter);
166+
167+
mlir::ConversionTarget target(getContext());
168+
target.addLegalOp<mlir::ModuleOp>();
169+
target.addLegalDialect<mlir::memref::MemRefDialect>();
170+
target.addIllegalDialect<cir::CIRDialect>();
171+
172+
if (failed(applyPartialConversion(module, target, std::move(patterns))))
173+
signalPassFailure();
174+
}
175+
176+
std::unique_ptr<mlir::Pass> createConvertCIRToMLIRPass() {
177+
return std::make_unique<ConvertCIRToMLIRPass>();
178+
}
179+
180+
mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp mlirModule,
181+
mlir::MLIRContext &mlirCtx) {
182+
llvm::TimeTraceScope scope("Lower CIR To MLIR");
183+
184+
mlir::PassManager pm(&mlirCtx);
185+
186+
pm.addPass(createConvertCIRToMLIRPass());
187+
188+
auto result = !mlir::failed(pm.run(mlirModule));
189+
if (!result)
190+
llvm::report_fatal_error(
191+
"The pass manager failed to lower CIR to MLIR standard dialects!");
192+
193+
// Now that we ran all the lowering passes, verify the final output.
194+
if (mlirModule.verify().failed())
195+
llvm::report_fatal_error(
196+
"Verification of the final MLIR in standard dialects failed!");
197+
198+
return mlirModule;
199+
}
200+
201+
} // namespace cir

0 commit comments

Comments
 (0)