3838#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
3939#include " mlir/Dialect/LLVMIR/LLVMTypes.h"
4040#include " mlir/Dialect/LLVMIR/NVVMDialect.h"
41+ #include " mlir/Dialect/LLVMIR/ROCDLDialect.h"
4142#include " mlir/Dialect/MemRef/IR/MemRef.h"
4243#include " mlir/Dialect/OpenMP/OpenMPDialect.h"
4344#include " mlir/Dialect/SCF/IR/SCF.h"
@@ -1718,8 +1719,18 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
17181719 ctorBuilder.create <LLVM::AddressOfOp>(loc, fatBinWrapper);
17191720 auto bitcastOfWrapper = ctorBuilder.create <LLVM::BitcastOp>(
17201721 loc, llvmPointerType, addressOfWrapper);
1721- auto module = rtRegisterFatBinaryCallBuilder.create (loc, ctorBuilder,
1722- {bitcastOfWrapper});
1722+
1723+ auto cudaRegisterFatbinFn = LLVM::lookupOrCreateFn (
1724+ rewriter, moduleOp, " __cudaRegisterFatBinary" , llvmPointerType,
1725+ llvmPointerType);
1726+ if (failed (cudaRegisterFatbinFn)) {
1727+ llvm::errs () << " cudamalloc already exists with different types\n " ;
1728+ return failure ();
1729+ }
1730+
1731+ auto module = rewriter.create <LLVM::CallOp>(
1732+ loc, cudaRegisterFatbinFn.value (), ValueRange (bitcastOfWrapper));
1733+
17231734 auto moduleGlobalName =
17241735 std::string (llvm::formatv (" polygeist_{0}_module_ptr" , moduleName));
17251736 {
@@ -1771,12 +1782,32 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
17711782 auto aoo = ctorBuilder.create <LLVM::AddressOfOp>(loc, stub);
17721783 auto bitcast =
17731784 ctorBuilder.create <LLVM::BitcastOp>(loc, llvmPointerType, aoo);
1774- auto ret = rtRegisterFunctionCallBuilder.create (
1775- loc, ctorBuilder,
1776- {module .getResult (), bitcast, kernelName, kernelName,
1777- /* TODO I have no idea what the following params are */
1778- ctorBuilder.create <LLVM::ConstantOp>(loc, llvmInt32Type, -1 ),
1779- nullPtr, nullPtr, nullPtr, nullPtr, nullPtr});
1785+
1786+ Type tys[] = {llvmPointerType, llvmPointerType, llvmPointerType,
1787+ llvmPointerType, llvmInt32Type, llvmPointerType,
1788+ llvmPointerType, llvmPointerType, llvmPointerType,
1789+ llvmPointerType};
1790+ auto cudaRegisterFn = LLVM::lookupOrCreateFn (
1791+ rewriter, moduleOp, " __cudaRegisterFunction" , tys,
1792+ llvmInt32Type);
1793+ if (failed (cudaRegisterFn)) {
1794+ llvm::errs ()
1795+ << " cudamalloc already exists with different types\n " ;
1796+ return failure ();
1797+ }
1798+ Value args[] = {
1799+ module .getResult (),
1800+ bitcast,
1801+ kernelName,
1802+ kernelName,
1803+ ctorBuilder.create <LLVM::ConstantOp>(loc, llvmInt32Type, -1 ),
1804+ nullPtr,
1805+ nullPtr,
1806+ nullPtr,
1807+ nullPtr,
1808+ nullPtr};
1809+
1810+ rewriter.create <LLVM::CallOp>(loc, cudaRegisterFn.value (), args);
17801811 } else if (LLVM::GlobalOp g = dyn_cast<LLVM::GlobalOp>(op)) {
17811812 int addrSpace = g.getAddrSpace ();
17821813 if (addrSpace != 1 /* device */ && addrSpace != 4 /* constant */ )
@@ -1825,9 +1856,18 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
18251856 }
18261857 }
18271858 // TODO this has to happen only for some CUDA versions
1828- if (gpuTarget == " cuda" )
1829- rtRegisterFatBinaryEndCallBuilder.create (loc, ctorBuilder,
1830- {module .getResult ()});
1859+ if (gpuTarget == " cuda" ) {
1860+ auto cudaRegisterFatbinFn = LLVM::lookupOrCreateFn (
1861+ rewriter, moduleOp, " __cudaRegisterFatBinaryEnd" , llvmPointerType,
1862+ llvmVoidType);
1863+ if (failed (cudaRegisterFatbinFn)) {
1864+ llvm::errs () << " cudamalloc already exists with different types\n " ;
1865+ return failure ();
1866+ }
1867+
1868+ rewriter.create <LLVM::CallOp>(loc, cudaRegisterFatbinFn.value (),
1869+ ValueRange (module ->getResult (0 )));
1870+ }
18311871 ctorBuilder.create <LLVM::ReturnOp>(loc, ValueRange ());
18321872 }
18331873 auto ctorSymbol = FlatSymbolRefAttr::get (ctor);
@@ -1847,8 +1887,17 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
18471887 auto aoo = dtorBuilder.create <LLVM::AddressOfOp>(loc, moduleGlobal);
18481888 auto module = dtorBuilder.create <LLVM::LoadOp>(
18491889 loc, llvmPointerPointerType, aoo->getResult (0 ));
1850- rtUnregisterFatBinaryCallBuilder.create (loc, dtorBuilder,
1851- module .getResult ());
1890+
1891+ auto cudaUnRegisterFatbinFn = LLVM::lookupOrCreateFn (
1892+ rewriter, moduleOp, " __cudaUnregisterFatBinary" , llvmPointerType,
1893+ llvmVoidType);
1894+ if (failed (cudaUnRegisterFatbinFn)) {
1895+ llvm::errs () << " cudamalloc already exists with different types\n " ;
1896+ return failure ();
1897+ }
1898+
1899+ rewriter.create <LLVM::CallOp>(loc, cudaUnRegisterFatbinFn.value (),
1900+ ValueRange (module ));
18521901 dtorBuilder.create <LLVM::ReturnOp>(loc, ValueRange ());
18531902 auto dtorSymbol = FlatSymbolRefAttr::get (dtor);
18541903 {
@@ -2469,6 +2518,34 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
24692518 }
24702519};
24712520
2521+ // / Pattern for returning from a function, packs the results into a struct.
2522+ struct GPUReturnOpLowering : public ConvertOpToLLVMPattern <gpu::ReturnOp> {
2523+ public:
2524+ using ConvertOpToLLVMPattern<gpu::ReturnOp>::ConvertOpToLLVMPattern;
2525+
2526+ LogicalResult
2527+ matchAndRewrite (gpu::ReturnOp returnOp, OpAdaptor adaptor,
2528+ ConversionPatternRewriter &rewriter) const override {
2529+ if (returnOp->getNumOperands () <= 1 ) {
2530+ rewriter.replaceOpWithNewOp <LLVM::ReturnOp>(returnOp,
2531+ adaptor.getOperands ());
2532+ return success ();
2533+ }
2534+
2535+ auto returnedType = LLVM::LLVMStructType::getLiteral (
2536+ returnOp->getContext (),
2537+ llvm::to_vector (adaptor.getOperands ().getTypes ()));
2538+ Value packed =
2539+ rewriter.create <LLVM::UndefOp>(returnOp->getLoc (), returnedType);
2540+ for (const auto &[index, value] : llvm::enumerate (adaptor.getOperands ())) {
2541+ packed = rewriter.create <LLVM::InsertValueOp>(returnOp->getLoc (), packed,
2542+ value, index);
2543+ }
2544+ rewriter.replaceOpWithNewOp <LLVM::ReturnOp>(returnOp, packed);
2545+ return success ();
2546+ }
2547+ };
2548+
24722549// / TODO: Temporary until we migrate everything to opaque pointers
24732550struct ReconcileUnrealizedPointerCasts
24742551 : public OpRewritePattern<UnrealizedConversionCastOp> {
@@ -2558,6 +2635,23 @@ populateCStyleMemRefLoweringPatterns(RewritePatternSet &patterns,
25582635 patterns.add <CMemcpyOpLowering>(typeConverter);
25592636}
25602637
2638+ // / Appends the patterns lowering operations from the Func dialect to the LLVM
2639+ // / dialect using the C-style type conversion, i.e. converting memrefs to
2640+ // / pointer to arrays of arrays.
2641+ static void
2642+ populateCStyleGPUFuncLoweringPatterns (RewritePatternSet &patterns,
2643+ LLVMTypeConverter &typeConverter,
2644+ std::string gpuTarget) {
2645+ patterns.add <GPUReturnOpLowering>(typeConverter);
2646+ patterns.add <GPUFuncOpLowering>(
2647+ typeConverter,
2648+ /* allocaAddrSpace=*/ 0 ,
2649+ StringAttr::get (&typeConverter.getContext (),
2650+ gpuTarget == " cuda"
2651+ ? NVVM::NVVMDialect::getKernelFuncAttrName ()
2652+ : ROCDL::ROCDLDialect::getKernelFuncAttrName ()));
2653+ }
2654+
25612655// / Appends the patterns lowering operations from the Func dialect to the LLVM
25622656// / dialect using the C-style type conversion, i.e. converting memrefs to
25632657// / pointer to arrays of arrays.
@@ -2618,6 +2712,13 @@ struct ConvertPolygeistToLLVMPass
26182712
26192713 RewritePatternSet patterns (&getContext ());
26202714
2715+ auto gpuTarget = " cuda" ;
2716+
2717+ // Insert our custom version of GPUFuncLowering
2718+ if (useCStyleMemRef) {
2719+ populateCStyleGPUFuncLoweringPatterns (patterns, converter, gpuTarget);
2720+ }
2721+
26212722 populatePolygeistToLLVMConversionPatterns (converter, patterns);
26222723 populateSCFToControlFlowConversionPatterns (patterns);
26232724 // populateForBreakToWhilePatterns(patterns);
@@ -2642,7 +2743,6 @@ struct ConvertPolygeistToLLVMPass
26422743
26432744 // Our custom versions of the gpu patterns
26442745 if (useCStyleMemRef) {
2645- auto gpuTarget = " cuda" ;
26462746 patterns.add <ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
26472747 converter, " gpu.binary" , gpuTarget);
26482748 // patterns.add<LegalizeLaunchFuncOpPattern>(
0 commit comments