@@ -2556,6 +2556,101 @@ class ConvertAllocOpToGpuRuntimeCallPattern
25562556 }
25572557};
25582558
2559+ class ConvertOccupancyOp
2560+ : public ConvertOpToGpuRuntimeCallPattern<enzymexla::GPUOccupancyOp> {
2561+ public:
2562+ // / The attribute name to use instead of `gpu.kernel`.
2563+ StringRef backend;
2564+
2565+ ConvertOccupancyOp (LLVMTypeConverter &typeConverter, StringRef backend)
2566+ : ConvertOpToGpuRuntimeCallPattern<enzymexla::GPUOccupancyOp>(
2567+ typeConverter),
2568+ backend (backend) {}
2569+
2570+ private:
2571+ LogicalResult
2572+ matchAndRewrite (enzymexla::GPUOccupancyOp op, OpAdaptor adaptor,
2573+ ConversionPatternRewriter &rewriter) const override {
2574+
2575+ if (failed (areAllLLVMTypes (op, adaptor.getOperands (), rewriter)))
2576+ return failure ();
2577+
2578+ if (backend != " cuda" )
2579+ return rewriter.notifyMatchFailure (
2580+ op, " Occupancy op lowering only supported for CUDA" );
2581+
2582+ auto moduleOp = op->getParentOfType <ModuleOp>();
2583+ auto i64 = rewriter.getIntegerType (64 );
2584+ auto i32 = rewriter.getIntegerType (32 );
2585+
2586+ auto intty = adaptor.getBlockSize ().getType ();
2587+ auto loc = op.getLoc ();
2588+
2589+ auto ptrty = LLVM::LLVMPointerType::get (rewriter.getContext ());
2590+ Type tys[] = {ptrty, ptrty, intty, adaptor.getDynamicSMemSize ().getType (),
2591+ adaptor.getFlags ().getType ()};
2592+
2593+ auto cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlagsFn =
2594+ LLVM::lookupOrCreateFn (
2595+ rewriter, moduleOp,
2596+ " cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags" , tys, i32 );
2597+ if (failed (cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlagsFn)) {
2598+ llvm::errs () << " cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags "
2599+ " already exists with different types\n " ;
2600+ return failure ();
2601+ }
2602+
2603+ auto one = rewriter.create <LLVM::ConstantOp>(loc, i64 ,
2604+ rewriter.getI64IntegerAttr (1 ));
2605+
2606+ auto ptr = rewriter.create <LLVM::AllocaOp>(loc, ptrty, intty, one);
2607+
2608+ std::string funcStubName =
2609+ getFuncStubName (op.getFn ().getRootReference ().getValue (),
2610+ op.getFn ().getLeafReference ().getValue ());
2611+ auto addr = rewriter.create <LLVM::AddressOfOp>(loc, ptrty, funcStubName);
2612+ Value args[] = {ptr, addr, adaptor.getBlockSize (),
2613+ adaptor.getDynamicSMemSize (), adaptor.getFlags ()};
2614+ rewriter.create <LLVM::CallOp>(
2615+ loc, cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlagsFn.value (),
2616+ args);
2617+ rewriter.replaceOpWithNewOp <LLVM::LoadOp>(op, intty, ptr);
2618+
2619+ return success ();
2620+ }
2621+ };
2622+
2623+ class ConvertGPUKernelAddressOp
2624+ : public ConvertOpToGpuRuntimeCallPattern<enzymexla::GPUKernelAddressOp> {
2625+ public:
2626+ // / The attribute name to use instead of `gpu.kernel`.
2627+ StringRef backend;
2628+
2629+ ConvertGPUKernelAddressOp (LLVMTypeConverter &typeConverter, StringRef backend)
2630+ : ConvertOpToGpuRuntimeCallPattern<enzymexla::GPUKernelAddressOp>(
2631+ typeConverter),
2632+ backend (backend) {}
2633+
2634+ private:
2635+ LogicalResult
2636+ matchAndRewrite (enzymexla::GPUKernelAddressOp op, OpAdaptor adaptor,
2637+ ConversionPatternRewriter &rewriter) const override {
2638+
2639+ if (backend != " cuda" )
2640+ return rewriter.notifyMatchFailure (
2641+ op, " KernelAddress lowering only supported for CUDA" );
2642+
2643+ std::string funcStubName =
2644+ getFuncStubName (op.getFn ().getRootReference ().getValue (),
2645+ op.getFn ().getLeafReference ().getValue ());
2646+
2647+ rewriter.replaceOpWithNewOp <LLVM::AddressOfOp>(op, op.getType (),
2648+ funcStubName);
2649+
2650+ return success ();
2651+ }
2652+ };
2653+
25592654// / A rewrite pattern to convert gpu.alloc operations into a GPU runtime
25602655// / call. Currently it supports CUDA, CPU, and XLA.
25612656template <bool cStyle>
@@ -3938,6 +4033,10 @@ struct ConvertPolygeistToLLVMPass
39384033 // /*kernelIntersperseSizeCallConv*/ false);
39394034 patterns.add <ConvertAllocOpToGpuRuntimeCallPattern<true >>(converter,
39404035 gpuTarget);
4036+ patterns.add <ConvertOccupancyOp>(converter, gpuTarget);
4037+
4038+ patterns.add <ConvertGPUKernelAddressOp>(converter, gpuTarget);
4039+
39414040 patterns.add <ConvertDeallocOpToGpuRuntimeCallPattern<true >>(converter,
39424041 gpuTarget);
39434042 patterns.add <ConvertXLAWrapperPattern<true >>(converter, gpuTarget);
0 commit comments