1010//
1111// ===----------------------------------------------------------------------===//
1212
13- #include " LowerToLLVM.h"
13+ #include " clang/CIR/ LowerToLLVM.h"
1414
15- #include " mlir/Conversion/LLVMCommon/TypeConverter.h"
16- #include " mlir/Dialect/DLTI/DLTI.h"
17- #include " mlir/Dialect/Func/IR/FuncOps.h"
18- #include " mlir/Dialect/LLVMIR/LLVMDialect.h"
19- #include " mlir/IR/BuiltinDialect.h"
2015#include " mlir/IR/BuiltinOps.h"
21- #include " mlir/Pass/Pass.h"
22- #include " mlir/Pass/PassManager.h"
23- #include " mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
24- #include " mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
25- #include " mlir/Target/LLVMIR/Export.h"
26- #include " mlir/Transforms/DialectConversion.h"
27- #include " clang/CIR/Dialect/IR/CIRDialect.h"
28- #include " clang/CIR/MissingFeatures.h"
2916#include " llvm/IR/Module.h"
3017#include " llvm/Support/TimeProfiler.h"
3118
@@ -35,165 +22,16 @@ using namespace llvm;
3522namespace cir {
3623namespace direct {
3724
38- // This pass requires the CIR to be in a "flat" state. All blocks in each
39- // function must belong to the parent region. Once scopes and control flow
40- // are implemented in CIR, a pass will be run before this one to flatten
41- // the CIR and get it into the state that this pass requires.
42- struct ConvertCIRToLLVMPass
43- : public mlir::PassWrapper<ConvertCIRToLLVMPass,
44- mlir::OperationPass<mlir::ModuleOp>> {
45- void getDependentDialects (mlir::DialectRegistry ®istry) const override {
46- registry.insert <mlir::BuiltinDialect, mlir::DLTIDialect,
47- mlir::LLVM::LLVMDialect, mlir::func::FuncDialect>();
48- }
49- void runOnOperation () final ;
50-
51- StringRef getDescription () const override {
52- return " Convert the prepared CIR dialect module to LLVM dialect" ;
53- }
54-
55- StringRef getArgument () const override { return " cir-flat-to-llvm" ; }
56- };
57-
58- mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite (
59- cir::GlobalOp op, OpAdaptor adaptor,
60- mlir::ConversionPatternRewriter &rewriter) const {
61-
62- // Fetch required values to create LLVM op.
63- const mlir::Type cirSymType = op.getSymType ();
64-
65- // This is the LLVM dialect type.
66- const mlir::Type llvmType = getTypeConverter ()->convertType (cirSymType);
67- // FIXME: These default values are placeholders until the the equivalent
68- // attributes are available on cir.global ops.
69- assert (!cir::MissingFeatures::opGlobalConstant ());
70- const bool isConst = false ;
71- assert (!cir::MissingFeatures::addressSpace ());
72- const unsigned addrSpace = 0 ;
73- assert (!cir::MissingFeatures::opGlobalDSOLocal ());
74- const bool isDsoLocal = true ;
75- assert (!cir::MissingFeatures::opGlobalThreadLocal ());
76- const bool isThreadLocal = false ;
77- assert (!cir::MissingFeatures::opGlobalAlignment ());
78- const uint64_t alignment = 0 ;
79- assert (!cir::MissingFeatures::opGlobalLinkage ());
80- const mlir::LLVM::Linkage linkage = mlir::LLVM::Linkage::External;
81- const StringRef symbol = op.getSymName ();
82- std::optional<mlir::Attribute> init = op.getInitialValue ();
83-
84- SmallVector<mlir::NamedAttribute> attributes;
85-
86- if (init.has_value ()) {
87- if (const auto fltAttr = mlir::dyn_cast<cir::FPAttr>(init.value ())) {
88- // Initializer is a constant floating-point number: convert to MLIR
89- // builtin constant.
90- init = rewriter.getFloatAttr (llvmType, fltAttr.getValue ());
91- } else if (const auto intAttr =
92- mlir::dyn_cast<cir::IntAttr>(init.value ())) {
93- // Initializer is a constant array: convert it to a compatible llvm init.
94- init = rewriter.getIntegerAttr (llvmType, intAttr.getValue ());
95- } else {
96- op.emitError () << " unsupported initializer '" << init.value () << " '" ;
97- return mlir::failure ();
98- }
99- }
100-
101- // Rewrite op.
102- rewriter.replaceOpWithNewOp <mlir::LLVM::GlobalOp>(
103- op, llvmType, isConst, linkage, symbol, init.value_or (mlir::Attribute ()),
104- alignment, addrSpace, isDsoLocal, isThreadLocal,
105- /* comdat=*/ mlir::SymbolRefAttr (), attributes);
106-
107- return mlir::success ();
108- }
109-
110- static void prepareTypeConverter (mlir::LLVMTypeConverter &converter,
111- mlir::DataLayout &dataLayout) {
112- converter.addConversion ([&](cir::IntType type) -> mlir::Type {
113- // LLVM doesn't work with signed types, so we drop the CIR signs here.
114- return mlir::IntegerType::get (type.getContext (), type.getWidth ());
115- });
116- converter.addConversion ([&](cir::SingleType type) -> mlir::Type {
117- return mlir::Float32Type::get (type.getContext ());
118- });
119- converter.addConversion ([&](cir::DoubleType type) -> mlir::Type {
120- return mlir::Float64Type::get (type.getContext ());
121- });
122- converter.addConversion ([&](cir::FP80Type type) -> mlir::Type {
123- return mlir::Float80Type::get (type.getContext ());
124- });
125- converter.addConversion ([&](cir::FP128Type type) -> mlir::Type {
126- return mlir::Float128Type::get (type.getContext ());
127- });
128- converter.addConversion ([&](cir::LongDoubleType type) -> mlir::Type {
129- return converter.convertType (type.getUnderlying ());
130- });
131- converter.addConversion ([&](cir::FP16Type type) -> mlir::Type {
132- return mlir::Float16Type::get (type.getContext ());
133- });
134- converter.addConversion ([&](cir::BF16Type type) -> mlir::Type {
135- return mlir::BFloat16Type::get (type.getContext ());
136- });
137- }
138-
139- void ConvertCIRToLLVMPass::runOnOperation () {
140- llvm::TimeTraceScope scope (" Convert CIR to LLVM Pass" );
141-
142- mlir::ModuleOp module = getOperation ();
143- mlir::DataLayout dl (module );
144- mlir::LLVMTypeConverter converter (&getContext ());
145- prepareTypeConverter (converter, dl);
146-
147- mlir::RewritePatternSet patterns (&getContext ());
148-
149- patterns.add <CIRToLLVMGlobalOpLowering>(converter, patterns.getContext (), dl);
150-
151- mlir::ConversionTarget target (getContext ());
152- target.addLegalOp <mlir::ModuleOp>();
153- target.addLegalDialect <mlir::LLVM::LLVMDialect>();
154- target.addIllegalDialect <mlir::BuiltinDialect, cir::CIRDialect,
155- mlir::func::FuncDialect>();
156-
157- if (failed (applyPartialConversion (module , target, std::move (patterns))))
158- signalPassFailure ();
159- }
160-
161- static std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass () {
162- return std::make_unique<ConvertCIRToLLVMPass>();
163- }
164-
165- static void populateCIRToLLVMPasses (mlir::OpPassManager &pm) {
166- pm.addPass (createConvertCIRToLLVMPass ());
167- }
168-
16925std::unique_ptr<llvm::Module>
17026lowerDirectlyFromCIRToLLVMIR (mlir::ModuleOp mlirModule, LLVMContext &llvmCtx) {
17127 llvm::TimeTraceScope scope (" lower from CIR to LLVM directly" );
17228
173- mlir::MLIRContext *mlirCtx = mlirModule.getContext ();
174-
175- mlir::PassManager pm (mlirCtx);
176- populateCIRToLLVMPasses (pm);
177-
178- if (mlir::failed (pm.run (mlirModule))) {
179- // FIXME: Handle any errors where they occurs and return a nullptr here.
180- report_fatal_error (
181- " The pass manager failed to lower CIR to LLVMIR dialect!" );
182- }
183-
184- mlir::registerBuiltinDialectTranslation (*mlirCtx);
185- mlir::registerLLVMDialectTranslation (*mlirCtx);
186-
187- llvm::TimeTraceScope translateScope (" translateModuleToLLVMIR" );
188-
189- StringRef moduleName = mlirModule.getName ().value_or (" CIRToLLVMModule" );
190- std::unique_ptr<llvm::Module> llvmModule =
191- mlir::translateModuleToLLVMIR (mlirModule, llvmCtx, moduleName);
29+ std::optional<StringRef> moduleName = mlirModule.getName ();
30+ auto llvmModule = std::make_unique<llvm::Module>(
31+ moduleName ? *moduleName : " CIRToLLVMModule" , llvmCtx);
19232
193- if (!llvmModule) {
194- // FIXME: Handle any errors where they occurs and return a nullptr here.
33+ if (!llvmModule)
19534 report_fatal_error (" Lowering from LLVMIR dialect to llvm IR failed!" );
196- }
19735
19836 return llvmModule;
19937}
0 commit comments