diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h b/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h new file mode 100644 index 0000000000000..bbba89cb7e3fd --- /dev/null +++ b/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h @@ -0,0 +1,52 @@ +//===- CIRAttrVisitor.h - Visitor for CIR attributes ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the CirAttrVisitor interface. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H +#define LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H + +#include "clang/CIR/Dialect/IR/CIRAttrs.h" + +namespace cir { + +template class CirAttrVisitor { +public: + // FIXME: Create a TableGen list to automatically handle new attributes + RetTy visit(mlir::Attribute attr) { + if (const auto intAttr = mlir::dyn_cast(attr)) + return getImpl().visitCirIntAttr(intAttr); + if (const auto fltAttr = mlir::dyn_cast(attr)) + return getImpl().visitCirFPAttr(fltAttr); + if (const auto ptrAttr = mlir::dyn_cast(attr)) + return getImpl().visitCirConstPtrAttr(ptrAttr); + llvm_unreachable("unhandled attribute type"); + } + + // If the implementation chooses not to implement a certain visit + // method, fall back to the parent. + RetTy visitCirIntAttr(cir::IntAttr attr) { + return getImpl().visitCirAttr(attr); + } + RetTy visitCirFPAttr(cir::FPAttr attr) { + return getImpl().visitCirAttr(attr); + } + RetTy visitCirConstPtrAttr(cir::ConstPtrAttr attr) { + return getImpl().visitCirAttr(attr); + } + + RetTy visitCirAttr(mlir::Attribute attr) { return RetTy(); } + + ImplClass &getImpl() { return *static_cast(this); } +}; + +} // namespace cir + +#endif // LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h index 3c018aeea6501..d4fcd52e7e6e3 100644 --- a/clang/include/clang/CIR/MissingFeatures.h +++ b/clang/include/clang/CIR/MissingFeatures.h @@ -27,6 +27,9 @@ struct MissingFeatures { // Address space related static bool addressSpace() { return false; } + // This isn't needed until we add support for bools. + static bool convertTypeForMemory() { return false; } + // Unhandled global/linkage information. static bool opGlobalDSOLocal() { return false; } static bool opGlobalThreadLocal() { return false; } diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index af8ca7d0b89e6..d60a6b38b0c12 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -24,6 +24,7 @@ #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" #include "mlir/Transforms/DialectConversion.h" +#include "clang/CIR/Dialect/IR/CIRAttrVisitor.h" #include "clang/CIR/Dialect/IR/CIRDialect.h" #include "clang/CIR/MissingFeatures.h" #include "llvm/IR/Module.h" @@ -35,6 +36,71 @@ using namespace llvm; namespace cir { namespace direct { +class CIRAttrToValue : public CirAttrVisitor { +public: + CIRAttrToValue(mlir::Operation *parentOp, + mlir::ConversionPatternRewriter &rewriter, + const mlir::TypeConverter *converter) + : parentOp(parentOp), rewriter(rewriter), converter(converter) {} + + mlir::Value lowerCirAttrAsValue(mlir::Attribute attr) { return visit(attr); } + + mlir::Value visitCirIntAttr(cir::IntAttr intAttr) { + mlir::Location loc = parentOp->getLoc(); + return rewriter.create( + loc, converter->convertType(intAttr.getType()), intAttr.getValue()); + } + + mlir::Value visitCirFPAttr(cir::FPAttr fltAttr) { + mlir::Location loc = parentOp->getLoc(); + return rewriter.create( + loc, converter->convertType(fltAttr.getType()), fltAttr.getValue()); + } + + mlir::Value visitCirConstPtrAttr(cir::ConstPtrAttr ptrAttr) { + mlir::Location loc = parentOp->getLoc(); + if (ptrAttr.isNullValue()) { + return rewriter.create( + loc, converter->convertType(ptrAttr.getType())); + } + mlir::DataLayout layout(parentOp->getParentOfType()); + mlir::Value ptrVal = rewriter.create( + loc, + rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())), + ptrAttr.getValue().getInt()); + return rewriter.create( + loc, converter->convertType(ptrAttr.getType()), ptrVal); + } + +private: + mlir::Operation *parentOp; + mlir::ConversionPatternRewriter &rewriter; + const mlir::TypeConverter *converter; +}; + +// This class handles rewriting initializer attributes for types that do not +// require region initialization. +class GlobalInitAttrRewriter + : public CirAttrVisitor { +public: + GlobalInitAttrRewriter(mlir::Type type, + mlir::ConversionPatternRewriter &rewriter) + : llvmType(type), rewriter(rewriter) {} + + mlir::Attribute rewriteInitAttr(mlir::Attribute attr) { return visit(attr); } + + mlir::Attribute visitCirIntAttr(cir::IntAttr attr) { + return rewriter.getIntegerAttr(llvmType, attr.getValue()); + } + mlir::Attribute visitCirFPAttr(cir::FPAttr attr) { + return rewriter.getFloatAttr(llvmType, attr.getValue()); + } + +private: + mlir::Type llvmType; + mlir::ConversionPatternRewriter &rewriter; +}; + // This pass requires the CIR to be in a "flat" state. All blocks in each // function must belong to the parent region. Once scopes and control flow // are implemented in CIR, a pass will be run before this one to flatten @@ -55,14 +121,81 @@ struct ConvertCIRToLLVMPass StringRef getArgument() const override { return "cir-flat-to-llvm"; } }; +bool CIRToLLVMGlobalOpLowering::attrRequiresRegionInitialization( + mlir::Attribute attr) const { + // There will be more cases added later. + return isa(attr); +} + +/// Replace CIR global with a region initialized LLVM global and update +/// insertion point to the end of the initializer block. +void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp( + cir::GlobalOp op, mlir::ConversionPatternRewriter &rewriter) const { + assert(!cir::MissingFeatures::convertTypeForMemory()); + const mlir::Type llvmType = getTypeConverter()->convertType(op.getSymType()); + + // FIXME: These default values are placeholders until the the equivalent + // attributes are available on cir.global ops. This duplicates code + // in CIRToLLVMGlobalOpLowering::matchAndRewrite() but that will go + // away when the placeholders are no longer needed. + assert(!cir::MissingFeatures::opGlobalConstant()); + const bool isConst = false; + assert(!cir::MissingFeatures::addressSpace()); + const unsigned addrSpace = 0; + assert(!cir::MissingFeatures::opGlobalDSOLocal()); + const bool isDsoLocal = true; + assert(!cir::MissingFeatures::opGlobalThreadLocal()); + const bool isThreadLocal = false; + assert(!cir::MissingFeatures::opGlobalAlignment()); + const uint64_t alignment = 0; + assert(!cir::MissingFeatures::opGlobalLinkage()); + const mlir::LLVM::Linkage linkage = mlir::LLVM::Linkage::External; + const StringRef symbol = op.getSymName(); + + SmallVector attributes; + auto newGlobalOp = rewriter.replaceOpWithNewOp( + op, llvmType, isConst, linkage, symbol, nullptr, alignment, addrSpace, + isDsoLocal, isThreadLocal, + /*comdat=*/mlir::SymbolRefAttr(), attributes); + newGlobalOp.getRegion().push_back(new mlir::Block()); + rewriter.setInsertionPointToEnd(newGlobalOp.getInitializerBlock()); +}; + +mlir::LogicalResult +CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal( + cir::GlobalOp op, mlir::Attribute init, + mlir::ConversionPatternRewriter &rewriter) const { + // TODO: Generalize this handling when more types are needed here. + assert(isa(init)); + + // TODO(cir): once LLVM's dialect has proper equivalent attributes this + // should be updated. For now, we use a custom op to initialize globals + // to the appropriate value. + const mlir::Location loc = op.getLoc(); + setupRegionInitializedLLVMGlobalOp(op, rewriter); + CIRAttrToValue attrVisitor(op, rewriter, typeConverter); + mlir::Value value = attrVisitor.lowerCirAttrAsValue(init); + rewriter.create(loc, value); + return mlir::success(); +} + mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite( cir::GlobalOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { + std::optional init = op.getInitialValue(); + + // If we have an initializer and it requires region initialization, handle + // that separately + if (init.has_value() && attrRequiresRegionInitialization(init.value())) { + return matchAndRewriteRegionInitializedGlobal(op, init.value(), rewriter); + } + // Fetch required values to create LLVM op. const mlir::Type cirSymType = op.getSymType(); // This is the LLVM dialect type. + assert(!cir::MissingFeatures::convertTypeForMemory()); const mlir::Type llvmType = getTypeConverter()->convertType(cirSymType); // FIXME: These default values are placeholders until the the equivalent // attributes are available on cir.global ops. @@ -79,20 +212,15 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite( assert(!cir::MissingFeatures::opGlobalLinkage()); const mlir::LLVM::Linkage linkage = mlir::LLVM::Linkage::External; const StringRef symbol = op.getSymName(); - std::optional init = op.getInitialValue(); - SmallVector attributes; if (init.has_value()) { - if (const auto fltAttr = mlir::dyn_cast(init.value())) { - // Initializer is a constant floating-point number: convert to MLIR - // builtin constant. - init = rewriter.getFloatAttr(llvmType, fltAttr.getValue()); - } else if (const auto intAttr = - mlir::dyn_cast(init.value())) { - // Initializer is a constant array: convert it to a compatible llvm init. - init = rewriter.getIntegerAttr(llvmType, intAttr.getValue()); - } else { + GlobalInitAttrRewriter initRewriter(llvmType, rewriter); + init = initRewriter.rewriteInitAttr(init.value()); + // If initRewriter returned a null attribute, init will have a value but + // the value will be null. If that happens, initRewriter didn't handle the + // attribute type. It probably needs to be added to GlobalInitAttrRewriter. + if (!init.value()) { op.emitError() << "unsupported initializer '" << init.value() << "'"; return mlir::failure(); } @@ -109,6 +237,13 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite( static void prepareTypeConverter(mlir::LLVMTypeConverter &converter, mlir::DataLayout &dataLayout) { + converter.addConversion([&](cir::PointerType type) -> mlir::Type { + // Drop pointee type since LLVM dialect only allows opaque pointers. + assert(!cir::MissingFeatures::addressSpace()); + unsigned targetAS = 0; + + return mlir::LLVM::LLVMPointerType::get(type.getContext(), targetAS); + }); converter.addConversion([&](cir::IntType type) -> mlir::Type { // LLVM doesn't work with signed types, so we drop the CIR signs here. return mlir::IntegerType::get(type.getContext(), type.getWidth()); diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index 6167ff39b5ad6..b3366c1fb9337 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -34,6 +34,16 @@ class CIRToLLVMGlobalOpLowering mlir::LogicalResult matchAndRewrite(cir::GlobalOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override; + +private: + bool attrRequiresRegionInitialization(mlir::Attribute attr) const; + + mlir::LogicalResult matchAndRewriteRegionInitializedGlobal( + cir::GlobalOp op, mlir::Attribute init, + mlir::ConversionPatternRewriter &rewriter) const; + + void setupRegionInitializedLLVMGlobalOp( + cir::GlobalOp op, mlir::ConversionPatternRewriter &rewriter) const; }; } // namespace direct diff --git a/clang/test/CIR/Lowering/global-var-simple.cpp b/clang/test/CIR/Lowering/global-var-simple.cpp index 06050e409d544..a5adb4011931a 100644 --- a/clang/test/CIR/Lowering/global-var-simple.cpp +++ b/clang/test/CIR/Lowering/global-var-simple.cpp @@ -79,3 +79,24 @@ long double ld; __float128 f128; // CHECK: @f128 = external dso_local global fp128 + +void *vp; +// CHECK: @vp = external dso_local global ptr{{$}} + +int *ip = 0; +// CHECK: @ip = dso_local global ptr null + +double *dp; +// CHECK: @dp = external dso_local global ptr{{$}} + +char **cpp; +// CHECK: @cpp = external dso_local global ptr{{$}} + +void (*fp)(); +// CHECK: @fp = external dso_local global ptr{{$}} + +int (*fpii)(int) = 0; +// CHECK: @fpii = dso_local global ptr null + +void (*fpvar)(int, ...); +// CHECK: @fpvar = external dso_local global ptr{{$}}