Commit 7f88655

Polygeist: handle occupancy and related utilities (#1395)
* WIP: occupance * wip * b * occ op * fmt * fix * fix * fix * fix * no comdat * More kernel support * fix * attr * Fix * fix * fix * fn attr * fix * fix * fix * fix
1 parent ccfcd69 commit 7f88655

File tree: 9 files changed, +976 −211 lines

src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td (17 additions, 0 deletions)

@@ -137,6 +137,23 @@ def GetStreamOp : EnzymeXLA_Op<"get_stream", [Pure]> {
 }


+def GPUOccupancyOp : EnzymeXLA_Op<"gpu_occupancy", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let arguments = (ins
+    SymbolRefAttr:$fn,
+    AnyType:$blockSize,
+    AnyType:$dynamicSMemSize,
+    AnyType:$flags
+  );
+  let results = (outs AnyType : $result);
+}
+
+def GPUKernelAddressOp : EnzymeXLA_Op<"gpu_kernel_address", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let arguments = (ins
+    SymbolRefAttr:$fn
+  );
+  let results = (outs AnyType : $result);
+}
+
 def GPUWrapperOp : EnzymeXLA_Op<"gpu_wrapper", [
   RecursiveMemoryEffects,
   AffineScope,
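For orientation, the two new ops model standard CUDA runtime functionality: gpu_occupancy corresponds to an occupancy query for the kernel named by $fn, and gpu_kernel_address to taking the address of that kernel's host-side stub (see the lowerings further below). A minimal host-side CUDA sketch of the same semantics; the saxpy kernel and the block size are hypothetical stand-ins, not part of this commit:

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical kernel standing in for whatever function the ops'
// SymbolRefAttr ($fn) would reference.
__global__ void saxpy(float a, const float *x, float *y) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  y[i] = a * x[i] + y[i];
}

int main() {
  // gpu_kernel_address: the host-side stub's address is the runtime's
  // handle for the kernel.
  const void *kernelAddr = (const void *)saxpy;

  // gpu_occupancy: maximum resident blocks per SM for a given block size,
  // dynamic shared memory budget, and flags.
  int numBlocks = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
      &numBlocks, kernelAddr, /*blockSize=*/256, /*dynamicSMemSize=*/0,
      cudaOccupancyDefault);

  printf("max active blocks per SM: %d\n", numBlocks);
  return 0;
}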

src/enzyme_ad/jax/Dialect/Ops.cpp (26 additions, 0 deletions)

@@ -68,6 +68,32 @@ static std::optional<int64_t> getConstant(Value v) {
   return {};
 }

+LogicalResult
+GPUOccupancyOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // TODO: Verify that the result type is same as the type of the referenced
+  // func.func op.
+  auto global = symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
+      *this, getFnAttr());
+  if (!global)
+    return emitOpError("'")
+           << getFn() << "' does not reference a valid global funcOp";
+
+  return success();
+}
+
+LogicalResult
+GPUKernelAddressOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // TODO: Verify that the result type is same as the type of the referenced
+  // func.func op.
+  auto global = symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
+      *this, getFnAttr());
+  if (!global)
+    return emitOpError("'")
+           << getFn() << "' does not reference a valid global funcOp";
+
+  return success();
+}
+
 LogicalResult
 KernelCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   // TODO: Verify that the result type is same as the type of the referenced

src/enzyme_ad/jax/Passes/ConvertParallelToGPU.cpp (25 additions, 0 deletions)

@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"

@@ -2529,6 +2530,30 @@ gdgo->erase();
           gmod.getContext(), /*optLevel*/ 2,
           /*triple*/ "nvptx64-nvidia-cuda", chip, features);
       gmod.setTargetsAttr(ArrayAttr::get(gmod.getContext(), target));
+
+      DataLayoutSpecInterface dataLayout = {};
+      // Set index type size to 32 bits
+      {
+        auto ctx = gmod.getContext();
+        llvm::DenseMap<mlir::TypeAttr, mlir::DataLayoutEntryInterface>
+            typeEntries;
+        auto type = IndexType::get(ctx);
+        auto key = mlir::TypeAttr::get(type);
+        uint64_t size = 32;
+        auto params =
+            IntegerAttr::get(mlir::IntegerType::get(ctx, 64), size);
+        typeEntries.try_emplace(key,
+                                DataLayoutEntryAttr::get(type, params));
+        SmallVector<DataLayoutEntryInterface> entries;
+        entries.reserve(typeEntries.size());
+        for (const auto &it : typeEntries)
+          entries.push_back(it.second);
+        dataLayout = DataLayoutSpecAttr::get(ctx, entries);
+      }
+      // gpuModule->setAttr(
+      //     LLVM::LLVMDialect::getDataLayoutAttrName(),
+      //     deviceModule->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()));
+      gmod->setAttr(DLTIDialect::kDataLayoutAttrName, dataLayout);
     }
   });
 });
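A side note on the block above: the DenseMap only matters when accumulating entries keyed by type; for this single entry the spec can be built directly. A minimal equivalent sketch, assuming the same MLIR DLTI API and the surrounding gmod; either way the net effect is a dlti.dl_spec attribute pinning index to 32 bits on the GPU module:

// Sketch only: same effect as the block above, without the map.
auto *ctx = gmod.getContext();
auto indexType = mlir::IndexType::get(ctx);
auto entry = mlir::DataLayoutEntryAttr::get(
    indexType,
    mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 64), /*size=*/32));
llvm::SmallVector<mlir::DataLayoutEntryInterface> entries{entry};
gmod->setAttr(mlir::DLTIDialect::kDataLayoutAttrName,
              mlir::DataLayoutSpecAttr::get(ctx, entries));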

src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp (99 additions, 0 deletions)

@@ -2556,6 +2556,101 @@ class ConvertAllocOpToGpuRuntimeCallPattern
   }
 };

+class ConvertOccupancyOp
+    : public ConvertOpToGpuRuntimeCallPattern<enzymexla::GPUOccupancyOp> {
+public:
+  /// The GPU runtime backend being targeted (e.g. "cuda").
+  StringRef backend;
+
+  ConvertOccupancyOp(LLVMTypeConverter &typeConverter, StringRef backend)
+      : ConvertOpToGpuRuntimeCallPattern<enzymexla::GPUOccupancyOp>(
+            typeConverter),
+        backend(backend) {}
+
+private:
+  LogicalResult
+  matchAndRewrite(enzymexla::GPUOccupancyOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
+      return failure();
+
+    if (backend != "cuda")
+      return rewriter.notifyMatchFailure(
+          op, "Occupancy op lowering only supported for CUDA");
+
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    auto i64 = rewriter.getIntegerType(64);
+    auto i32 = rewriter.getIntegerType(32);
+
+    auto intty = adaptor.getBlockSize().getType();
+    auto loc = op.getLoc();
+
+    auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext());
+    Type tys[] = {ptrty, ptrty, intty, adaptor.getDynamicSMemSize().getType(),
+                  adaptor.getFlags().getType()};
+
+    auto cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlagsFn =
+        LLVM::lookupOrCreateFn(
+            rewriter, moduleOp,
+            "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", tys, i32);
+    if (failed(cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlagsFn)) {
+      llvm::errs() << " cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags "
+                      "already exists with different types\n";
+      return failure();
+    }
+
+    auto one = rewriter.create<LLVM::ConstantOp>(loc, i64,
+                                                 rewriter.getI64IntegerAttr(1));
+
+    auto ptr = rewriter.create<LLVM::AllocaOp>(loc, ptrty, intty, one);
+
+    std::string funcStubName =
+        getFuncStubName(op.getFn().getRootReference().getValue(),
+                        op.getFn().getLeafReference().getValue());
+    auto addr = rewriter.create<LLVM::AddressOfOp>(loc, ptrty, funcStubName);
+    Value args[] = {ptr, addr, adaptor.getBlockSize(),
+                    adaptor.getDynamicSMemSize(), adaptor.getFlags()};
+    rewriter.create<LLVM::CallOp>(
+        loc, cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlagsFn.value(),
+        args);
+    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, intty, ptr);
+
+    return success();
+  }
+};
+
+class ConvertGPUKernelAddressOp
+    : public ConvertOpToGpuRuntimeCallPattern<enzymexla::GPUKernelAddressOp> {
+public:
+  /// The GPU runtime backend being targeted (e.g. "cuda").
+  StringRef backend;
+
+  ConvertGPUKernelAddressOp(LLVMTypeConverter &typeConverter, StringRef backend)
+      : ConvertOpToGpuRuntimeCallPattern<enzymexla::GPUKernelAddressOp>(
+            typeConverter),
+        backend(backend) {}
+
+private:
+  LogicalResult
+  matchAndRewrite(enzymexla::GPUKernelAddressOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (backend != "cuda")
+      return rewriter.notifyMatchFailure(
+          op, "KernelAddress lowering only supported for CUDA");
+
+    std::string funcStubName =
+        getFuncStubName(op.getFn().getRootReference().getValue(),
+                        op.getFn().getLeafReference().getValue());
+
+    rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, op.getType(),
+                                                   funcStubName);
+
+    return success();
+  }
+};
+
 /// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
 /// call. Currently it supports CUDA, CPU, and XLA.
 template <bool cStyle>
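Why an llvm.mlir.addressof of the registered host stub is the right lowering for GPUKernelAddressOp: the stub's address is exactly the kernel handle that CUDA's driver-agnostic launch API consumes. A hedged sketch in plain CUDA host code; vecAdd and launchVecAdd are hypothetical, not part of this commit:

#include <cuda_runtime.h>

__global__ void vecAdd(const float *x, float *y, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    y[i] += x[i];
}

cudaError_t launchVecAdd(const float *x, float *y, int n,
                         cudaStream_t stream) {
  // The host-side stub address — what gpu_kernel_address materializes as an
  // llvm.mlir.addressof — is the handle cudaLaunchKernel expects.
  const void *kernelAddr = (const void *)vecAdd;
  void *args[] = {&x, &y, &n};
  dim3 grid((n + 255) / 256), block(256);
  return cudaLaunchKernel(kernelAddr, grid, block, args,
                          /*sharedMem=*/0, stream);
}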
@@ -3938,6 +4033,10 @@ struct ConvertPolygeistToLLVMPass
   // /*kernelIntersperseSizeCallConv*/ false);
   patterns.add<ConvertAllocOpToGpuRuntimeCallPattern<true>>(converter,
                                                             gpuTarget);
+  patterns.add<ConvertOccupancyOp>(converter, gpuTarget);
+
+  patterns.add<ConvertGPUKernelAddressOp>(converter, gpuTarget);
+
   patterns.add<ConvertDeallocOpToGpuRuntimeCallPattern<true>>(converter,
                                                               gpuTarget);
   patterns.add<ConvertXLAWrapperPattern<true>>(converter, gpuTarget);
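For readers less familiar with MLIR's dialect-conversion flow, a sketch of how registered patterns like these are typically driven. This follows standard applyPartialConversion usage rather than code from this pass; converter, gpuTarget, and module stand in for the pass's actual state:

// Sketch under standard MLIR dialect-conversion assumptions.
mlir::LLVMConversionTarget target(*module.getContext());
// Mark the new ops illegal so conversion must rewrite them away.
target.addIllegalOp<enzymexla::GPUOccupancyOp,
                    enzymexla::GPUKernelAddressOp>();
mlir::RewritePatternSet patterns(module.getContext());
patterns.add<ConvertOccupancyOp>(converter, gpuTarget);
patterns.add<ConvertGPUKernelAddressOp>(converter, gpuTarget);
if (mlir::failed(mlir::applyPartialConversion(module, target,
                                              std::move(patterns))))
  signalPassFailure();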
