1010//
1111// ===----------------------------------------------------------------------===//
1212
13- #include " clang/CIR/ LowerToLLVM.h"
13+ #include " LowerToLLVM.h"
1414
15+ #include " mlir/Conversion/LLVMCommon/TypeConverter.h"
16+ #include " mlir/Dialect/DLTI/DLTI.h"
17+ #include " mlir/Dialect/Func/IR/FuncOps.h"
18+ #include " mlir/Dialect/LLVMIR/LLVMDialect.h"
19+ #include " mlir/IR/BuiltinDialect.h"
1520#include " mlir/IR/BuiltinOps.h"
21+ #include " mlir/Pass/Pass.h"
22+ #include " mlir/Pass/PassManager.h"
23+ #include " mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
24+ #include " mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
25+ #include " mlir/Target/LLVMIR/Export.h"
26+ #include " mlir/Transforms/DialectConversion.h"
27+ #include " clang/CIR/Dialect/IR/CIRDialect.h"
28+ #include " clang/CIR/MissingFeatures.h"
1629#include " llvm/IR/Module.h"
1730#include " llvm/Support/TimeProfiler.h"
1831
@@ -22,16 +35,165 @@ using namespace llvm;
2235namespace cir {
2336namespace direct {
2437
38+ // This pass requires the CIR to be in a "flat" state. All blocks in each
39+ // function must belong to the parent region. Once scopes and control flow
40+ // are implemented in CIR, a pass will be run before this one to flatten
41+ // the CIR and get it into the state that this pass requires.
42+ struct ConvertCIRToLLVMPass
43+ : public mlir::PassWrapper<ConvertCIRToLLVMPass,
44+ mlir::OperationPass<mlir::ModuleOp>> {
45+ void getDependentDialects (mlir::DialectRegistry ®istry) const override {
46+ registry.insert <mlir::BuiltinDialect, mlir::DLTIDialect,
47+ mlir::LLVM::LLVMDialect, mlir::func::FuncDialect>();
48+ }
49+ void runOnOperation () final ;
50+
51+ StringRef getDescription () const override {
52+ return " Convert the prepared CIR dialect module to LLVM dialect" ;
53+ }
54+
55+ StringRef getArgument () const override { return " cir-flat-to-llvm" ; }
56+ };
57+
58+ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite (
59+ cir::GlobalOp op, OpAdaptor adaptor,
60+ mlir::ConversionPatternRewriter &rewriter) const {
61+
62+ // Fetch required values to create LLVM op.
63+ const mlir::Type cirSymType = op.getSymType ();
64+
65+ // This is the LLVM dialect type.
66+ const mlir::Type llvmType = getTypeConverter ()->convertType (cirSymType);
67+ // FIXME: These default values are placeholders until the the equivalent
68+ // attributes are available on cir.global ops.
69+ assert (!cir::MissingFeatures::opGlobalConstant ());
70+ const bool isConst = false ;
71+ assert (!cir::MissingFeatures::addressSpace ());
72+ const unsigned addrSpace = 0 ;
73+ assert (!cir::MissingFeatures::opGlobalDSOLocal ());
74+ const bool isDsoLocal = true ;
75+ assert (!cir::MissingFeatures::opGlobalThreadLocal ());
76+ const bool isThreadLocal = false ;
77+ assert (!cir::MissingFeatures::opGlobalAlignment ());
78+ const uint64_t alignment = 0 ;
79+ assert (!cir::MissingFeatures::opGlobalLinkage ());
80+ const mlir::LLVM::Linkage linkage = mlir::LLVM::Linkage::External;
81+ const StringRef symbol = op.getSymName ();
82+ std::optional<mlir::Attribute> init = op.getInitialValue ();
83+
84+ SmallVector<mlir::NamedAttribute> attributes;
85+
86+ if (init.has_value ()) {
87+ if (const auto fltAttr = mlir::dyn_cast<cir::FPAttr>(init.value ())) {
88+ // Initializer is a constant floating-point number: convert to MLIR
89+ // builtin constant.
90+ init = rewriter.getFloatAttr (llvmType, fltAttr.getValue ());
91+ } else if (const auto intAttr =
92+ mlir::dyn_cast<cir::IntAttr>(init.value ())) {
93+ // Initializer is a constant array: convert it to a compatible llvm init.
94+ init = rewriter.getIntegerAttr (llvmType, intAttr.getValue ());
95+ } else {
96+ op.emitError () << " unsupported initializer '" << init.value () << " '" ;
97+ return mlir::failure ();
98+ }
99+ }
100+
101+ // Rewrite op.
102+ rewriter.replaceOpWithNewOp <mlir::LLVM::GlobalOp>(
103+ op, llvmType, isConst, linkage, symbol, init.value_or (mlir::Attribute ()),
104+ alignment, addrSpace, isDsoLocal, isThreadLocal,
105+ /* comdat=*/ mlir::SymbolRefAttr (), attributes);
106+
107+ return mlir::success ();
108+ }
109+
110+ static void prepareTypeConverter (mlir::LLVMTypeConverter &converter,
111+ mlir::DataLayout &dataLayout) {
112+ converter.addConversion ([&](cir::IntType type) -> mlir::Type {
113+ // LLVM doesn't work with signed types, so we drop the CIR signs here.
114+ return mlir::IntegerType::get (type.getContext (), type.getWidth ());
115+ });
116+ converter.addConversion ([&](cir::SingleType type) -> mlir::Type {
117+ return mlir::Float32Type::get (type.getContext ());
118+ });
119+ converter.addConversion ([&](cir::DoubleType type) -> mlir::Type {
120+ return mlir::Float64Type::get (type.getContext ());
121+ });
122+ converter.addConversion ([&](cir::FP80Type type) -> mlir::Type {
123+ return mlir::Float80Type::get (type.getContext ());
124+ });
125+ converter.addConversion ([&](cir::FP128Type type) -> mlir::Type {
126+ return mlir::Float128Type::get (type.getContext ());
127+ });
128+ converter.addConversion ([&](cir::LongDoubleType type) -> mlir::Type {
129+ return converter.convertType (type.getUnderlying ());
130+ });
131+ converter.addConversion ([&](cir::FP16Type type) -> mlir::Type {
132+ return mlir::Float16Type::get (type.getContext ());
133+ });
134+ converter.addConversion ([&](cir::BF16Type type) -> mlir::Type {
135+ return mlir::BFloat16Type::get (type.getContext ());
136+ });
137+ }
138+
139+ void ConvertCIRToLLVMPass::runOnOperation () {
140+ llvm::TimeTraceScope scope (" Convert CIR to LLVM Pass" );
141+
142+ mlir::ModuleOp module = getOperation ();
143+ mlir::DataLayout dl (module );
144+ mlir::LLVMTypeConverter converter (&getContext ());
145+ prepareTypeConverter (converter, dl);
146+
147+ mlir::RewritePatternSet patterns (&getContext ());
148+
149+ patterns.add <CIRToLLVMGlobalOpLowering>(converter, patterns.getContext (), dl);
150+
151+ mlir::ConversionTarget target (getContext ());
152+ target.addLegalOp <mlir::ModuleOp>();
153+ target.addLegalDialect <mlir::LLVM::LLVMDialect>();
154+ target.addIllegalDialect <mlir::BuiltinDialect, cir::CIRDialect,
155+ mlir::func::FuncDialect>();
156+
157+ if (failed (applyPartialConversion (module , target, std::move (patterns))))
158+ signalPassFailure ();
159+ }
160+
161+ static std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass () {
162+ return std::make_unique<ConvertCIRToLLVMPass>();
163+ }
164+
165+ static void populateCIRToLLVMPasses (mlir::OpPassManager &pm) {
166+ pm.addPass (createConvertCIRToLLVMPass ());
167+ }
168+
25169std::unique_ptr<llvm::Module>
26170lowerDirectlyFromCIRToLLVMIR (mlir::ModuleOp mlirModule, LLVMContext &llvmCtx) {
27171 llvm::TimeTraceScope scope (" lower from CIR to LLVM directly" );
28172
29- std::optional<StringRef> moduleName = mlirModule.getName ();
30- auto llvmModule = std::make_unique<llvm::Module>(
31- moduleName ? *moduleName : " CIRToLLVMModule" , llvmCtx);
173+ mlir::MLIRContext *mlirCtx = mlirModule.getContext ();
174+
175+ mlir::PassManager pm (mlirCtx);
176+ populateCIRToLLVMPasses (pm);
177+
178+ if (mlir::failed (pm.run (mlirModule))) {
179+ // FIXME: Handle any errors where they occurs and return a nullptr here.
180+ report_fatal_error (
181+ " The pass manager failed to lower CIR to LLVMIR dialect!" );
182+ }
183+
184+ mlir::registerBuiltinDialectTranslation (*mlirCtx);
185+ mlir::registerLLVMDialectTranslation (*mlirCtx);
186+
187+ llvm::TimeTraceScope translateScope (" translateModuleToLLVMIR" );
188+
189+ StringRef moduleName = mlirModule.getName ().value_or (" CIRToLLVMModule" );
190+ std::unique_ptr<llvm::Module> llvmModule =
191+ mlir::translateModuleToLLVMIR (mlirModule, llvmCtx, moduleName);
32192
33- if (!llvmModule)
193+ if (!llvmModule) {
194+ // FIXME: Handle any errors where they occurs and return a nullptr here.
34195 report_fatal_error (" Lowering from LLVMIR dialect to llvm IR failed!" );
196+ }
35197
36198 return llvmModule;
37199}
0 commit comments