@@ -44,6 +44,12 @@ static llvm::cl::opt<bool>
4444 llvm::cl::desc (" Replace emission of malloc/free by alloca" ),
4545 llvm::cl::init(false ));
4646
47+ static llvm::cl::opt<bool > clUseBarePtrCallConv (
48+ PASS_NAME " -use-bare-ptr-memref-call-conv" ,
49+ llvm::cl::desc (" Replace FuncOp's MemRef arguments with "
50+ " bare pointers to the MemRef element types" ),
51+ llvm::cl::init(false ));
52+
4753LLVMTypeConverter::LLVMTypeConverter (MLIRContext *ctx)
4854 : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) {
4955 assert (llvmDialect && " LLVM IR dialect is not registered" );
@@ -239,6 +245,60 @@ Type LLVMTypeConverter::convertStandardType(Type t) {
239245 .Default ([](Type) { return Type (); });
240246}
241247
248+ // Converts function signature following LLVMTypeConverter approach but
249+ // replacing the type of MemRef arguments with a bare LLVM pointer to
250+ // the MemRef element type.
251+ LLVM::LLVMType BarePtrTypeConverter::convertFunctionSignature (
252+ FunctionType type, bool isVariadic,
253+ LLVMTypeConverter::SignatureConversion &result) {
254+ // Convert argument types one by one and check for errors.
255+ for (auto &en : llvm::enumerate (type.getInputs ())) {
256+ Type type = en.value ();
257+ Type converted;
258+ if (auto memrefTy = type.dyn_cast <MemRefType>())
259+ converted = convertMemRefTypeToBarePtr (memrefTy)
260+ .dyn_cast_or_null <LLVM::LLVMType>();
261+ else
262+ converted = convertType (type).dyn_cast_or_null <LLVM::LLVMType>();
263+
264+ if (!converted)
265+ return {};
266+ result.addInputs (en.index (), converted);
267+ }
268+
269+ SmallVector<LLVM::LLVMType, 8 > argTypes;
270+ argTypes.reserve (llvm::size (result.getConvertedTypes ()));
271+ for (Type type : result.getConvertedTypes ())
272+ argTypes.push_back (unwrap (type));
273+
274+ // If function does not return anything, create the void result type, if it
275+ // returns on element, convert it, otherwise pack the result types into a
276+ // struct.
277+ LLVM::LLVMType resultType =
278+ type.getNumResults () == 0
279+ ? LLVM::LLVMType::getVoidTy (llvmDialect)
280+ : unwrap (packFunctionResults (type.getResults ()));
281+ if (!resultType)
282+ return {};
283+ return LLVM::LLVMType::getFunctionTy (resultType, argTypes, isVariadic);
284+ }
285+
286+ // Converts MemRefType to a bare LLVM pointer to the MemRef element type.
287+ Type BarePtrTypeConverter::convertMemRefTypeToBarePtr (MemRefType type) {
288+ int64_t offset;
289+ SmallVector<int64_t , 4 > strides;
290+ bool strideSuccess = succeeded (getStridesAndOffset (type, strides, offset));
291+ assert (strideSuccess &&
292+ " Non-strided layout maps must have been normalized away" );
293+ (void )strideSuccess;
294+
295+ LLVM::LLVMType elementType = unwrap (convertType (type.getElementType ()));
296+ if (!elementType)
297+ return {};
298+ auto ptrTy = elementType.getPointerTo (type.getMemorySpace ());
299+ return ptrTy;
300+ }
301+
242302LLVMOpLowering::LLVMOpLowering (StringRef rootOpName, MLIRContext *context,
243303 LLVMTypeConverter &lowering_,
244304 PatternBenefit benefit)
@@ -548,7 +608,84 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
548608 for (unsigned idx : promotedArgIndices) {
549609 BlockArgument arg = firstBlock->getArgument (idx);
550610 Value loaded = rewriter.create <LLVM::LoadOp>(funcOp.getLoc (), arg);
551- rewriter.replaceUsesOfBlockArgument (arg, loaded);
611+ rewriter.replaceUsesOfWith (arg, loaded);
612+ }
613+ }
614+
615+ rewriter.eraseOp (op);
616+ return matchSuccess ();
617+ }
618+ };
619+
620+ // FuncOp conversion that converts MemRef arguments to bare pointers to the type
621+ // of the MemRef.
622+ struct BarePtrFuncOpConversion : public LLVMLegalizationPattern <FuncOp> {
623+ using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;
624+
625+ PatternMatchResult
626+ matchAndRewrite (Operation *op, ArrayRef<Value> operands,
627+ ConversionPatternRewriter &rewriter) const override {
628+ auto funcOp = cast<FuncOp>(op);
629+ FunctionType type = funcOp.getType ();
630+ auto funcLoc = funcOp.getLoc ();
631+
632+ // Store the positions of memref-typed arguments so that we can promote them
633+ // to MemRef descriptor structs at the beginning of the function.
634+ SmallVector<std::pair<unsigned , Type>, 4 > promotedArgIndices;
635+ promotedArgIndices.reserve (type.getNumInputs ());
636+ for (auto en : llvm::enumerate (type.getInputs ())) {
637+ if (en.value ().isa <MemRefType>())
638+ promotedArgIndices.push_back ({en.index (), en.value ()});
639+ }
640+
641+ // Convert the original function arguments. MemRef types are lowered to bare
642+ // pointers to the MemRef element type.
643+ auto varargsAttr = funcOp.getAttrOfType <BoolAttr>(" std.varargs" );
644+ TypeConverter::SignatureConversion result (funcOp.getNumArguments ());
645+ auto llvmType = lowering.convertFunctionSignature (
646+ funcOp.getType (), varargsAttr && varargsAttr.getValue (), result);
647+
648+ // Only retain those attributes that are not constructed by build.
649+ SmallVector<NamedAttribute, 4 > attributes;
650+ for (const auto &attr : funcOp.getAttrs ()) {
651+ if (attr.first .is (SymbolTable::getSymbolAttrName ()) ||
652+ attr.first .is (impl::getTypeAttrName ()) ||
653+ attr.first .is (" std.varargs" ))
654+ continue ;
655+ attributes.push_back (attr);
656+ }
657+
658+ // Create an LLVM function, use external linkage by default until MLIR
659+ // functions have linkage.
660+ auto newFuncOp =
661+ rewriter.create <LLVM::LLVMFuncOp>(funcLoc, funcOp.getName (), llvmType,
662+ LLVM::Linkage::External, attributes);
663+ rewriter.inlineRegionBefore (funcOp.getBody (), newFuncOp.getBody (),
664+ newFuncOp.end ());
665+
666+ // Tell the rewriter to convert the region signature.
667+ rewriter.applySignatureConversion (&newFuncOp.getBody (), result);
668+
669+ // Promote bare pointers from MemRef arguments to a MemRef descriptor struct
670+ // at the beginning of the function so that all the MemRefs in the function
671+ // have a uniform representation.
672+ if (!newFuncOp.getBody ().empty ()) {
673+ Block *firstBlock = &newFuncOp.getBody ().front ();
674+ rewriter.setInsertionPoint (firstBlock, firstBlock->begin ());
675+ for (auto argIdxTypePair : promotedArgIndices) {
676+ // Replace argument with a placeholder (undef), promote argument to a
677+ // MemRef descriptor and replace placeholder with the last instruction
678+ // of the MemRef descriptor. The placeholder is needed to avoid
679+ // replacing argument uses in the MemRef descriptor instructions.
680+ BlockArgument arg = firstBlock->getArgument (argIdxTypePair.first );
681+ Value placeHolder =
682+ rewriter.create <LLVM::UndefOp>(funcLoc, arg.getType ());
683+ rewriter.replaceUsesOfWith (arg, placeHolder);
684+ auto desc = MemRefDescriptor::fromStaticShape (
685+ rewriter, funcLoc, lowering,
686+ argIdxTypePair.second .cast <MemRefType>(), arg);
687+ rewriter.replaceUsesOfWith (placeHolder, desc);
688+ placeHolder.getDefiningOp ()->erase ();
552689 }
553690 }
554691
@@ -2126,7 +2263,6 @@ void mlir::populateStdToLLVMMemoryConversionPatters(
21262263 // clang-format off
21272264 patterns.insert <
21282265 DimOpLowering,
2129- FuncOpConversion,
21302266 LoadOpLowering,
21312267 MemRefCastOpLowering,
21322268 StoreOpLowering,
@@ -2139,8 +2275,26 @@ void mlir::populateStdToLLVMMemoryConversionPatters(
21392275 // clang-format on
21402276}
21412277
2278+ void mlir::populateStdToLLVMDefaultFuncOpConversionPattern (
2279+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
2280+ patterns.insert <FuncOpConversion>(*converter.getDialect (), converter);
2281+ }
2282+
21422283void mlir::populateStdToLLVMConversionPatterns (
21432284 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
2285+ populateStdToLLVMDefaultFuncOpConversionPattern (converter, patterns);
2286+ populateStdToLLVMNonMemoryConversionPatterns (converter, patterns);
2287+ populateStdToLLVMMemoryConversionPatters (converter, patterns);
2288+ }
2289+
2290+ void mlir::populateStdToLLVMBarePtrFuncOpConversionPattern (
2291+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
2292+ patterns.insert <BarePtrFuncOpConversion>(*converter.getDialect (), converter);
2293+ }
2294+
2295+ void mlir::populateStdToLLVMBarePtrConversionPatterns (
2296+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
2297+ populateStdToLLVMBarePtrFuncOpConversionPattern (converter, patterns);
21442298 populateStdToLLVMNonMemoryConversionPatterns (converter, patterns);
21452299 populateStdToLLVMMemoryConversionPatters (converter, patterns);
21462300}
@@ -2210,6 +2364,12 @@ makeStandardToLLVMTypeConverter(MLIRContext *context) {
22102364 return std::make_unique<LLVMTypeConverter>(context);
22112365}
22122366
2367+ // / Create an instance of BarePtrTypeConverter in the given context.
2368+ static std::unique_ptr<LLVMTypeConverter>
2369+ makeStandardToLLVMBarePtrTypeConverter (MLIRContext *context) {
2370+ return std::make_unique<BarePtrTypeConverter>(context);
2371+ }
2372+
22132373namespace {
22142374// / A pass converting MLIR operations into the LLVM IR dialect.
22152375struct LLVMLoweringPass : public ModulePass <LLVMLoweringPass> {
@@ -2274,6 +2434,9 @@ static PassRegistration<LLVMLoweringPass>
22742434 " Standard to the LLVM dialect" ,
22752435 [] {
22762436 return std::make_unique<LLVMLoweringPass>(
2277- clUseAlloca.getValue (), populateStdToLLVMConversionPatterns,
2278- makeStandardToLLVMTypeConverter);
2437+ clUseAlloca.getValue (),
2438+ clUseBarePtrCallConv ? populateStdToLLVMBarePtrConversionPatterns
2439+ : populateStdToLLVMConversionPatterns,
2440+ clUseBarePtrCallConv ? makeStandardToLLVMBarePtrTypeConverter
2441+ : makeStandardToLLVMTypeConverter);
22792442 });
0 commit comments