
Commit ee040a2

continuing gpu backend (#1035)

Squashed commits:
* continuing
* fmt
* raise fix
* fmt
* fix
* fmt
* fix debug1
* debug info fix
* fmt
1 parent df04e2f commit ee040a2

6 files changed (+263 -80 lines changed)

src/enzyme_ad/jax/Passes/ConvertParallelToGPU.cpp

Lines changed: 50 additions & 0 deletions
@@ -1541,6 +1541,11 @@ struct ParallelToGPULaunch : public OpRewritePattern<enzymexla::GPUWrapperOp> {
 
     rewriter.setInsertionPoint(wrapper);
     auto errOp = rewriter.create<enzymexla::GPUErrorOp>(loc);
+
+    for (auto atname : {"passthrough", "target_features"})
+      if (auto attr = wrapper->getAttr(atname)) {
+        errOp->setAttr(atname, attr);
+      }
     rewriter.setInsertionPointToStart(errOp.getBody());
     rewriter.eraseOp(wrapper.getBody()->getTerminator());
     rewriter.inlineBlockBefore(wrapper.getBody(),
@@ -2238,6 +2243,51 @@ gdgo->erase();
       signalPassFailure();
       return;
     }
+    SymbolTableCollection symbolTable;
+    symbolTable.getSymbolTable(getOperation());
+    getOperation()->walk([&](GPUErrorOp err) {
+      std::string sm;
+      if (auto attr =
+              dyn_cast_or_null<ArrayAttr>(err->getAttr("passthrough"))) {
+        for (auto a : attr) {
+          if (auto ar = dyn_cast<ArrayAttr>(a)) {
+            if (ar.size() != 2)
+              continue;
+            auto s0 = dyn_cast<StringAttr>(ar[0]);
+            auto s1 = dyn_cast<StringAttr>(ar[1]);
+            if (!s0 || !s1)
+              continue;
+            if (s0.getValue() == "target-cpu")
+              sm = s1.getValue();
+          }
+        }
+      }
+      std::string feat;
+      if (auto attr = dyn_cast_or_null<LLVM::TargetFeaturesAttr>(
+              err->getAttr("target_features"))) {
+        feat = attr.getFeaturesString();
+      }
+
+      err->walk([&](gpu::LaunchFuncOp launch) {
+        auto gfunc = dyn_cast_or_null<gpu::GPUFuncOp>(
+            symbolTable.lookupNearestSymbolFrom(launch, launch.getKernel()));
+        if (!gfunc)
+          return;
+        auto gmod = cast<gpu::GPUModuleOp>(gfunc->getParentOp());
+        if (!gmod.getTargetsAttr()) {
+          auto chip = sm;
+          if (chip.size() == 0)
+            chip = "sm_50";
+          auto features = feat;
+          if (features.size() == 0)
+            features = "+ptx60";
+          auto target = NVVM::NVVMTargetAttr::get(
+              gmod.getContext(), /*optLevel*/ 2,
+              /*triple*/ "nvptx64-nvidia-cuda", chip, features);
+          gmod.setTargetsAttr(ArrayAttr::get(gmod.getContext(), target));
+        }
+      });
+    });
   }
 };
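Note: the two hunks above cooperate. Attributes forwarded from the original kernel ("passthrough", "target_features") are copied from the matched GPUWrapperOp onto the GPUErrorOp, then mined for a "target-cpu" entry and a feature string so that any kernel gpu.module still missing a target gets an NVVM one, defaulting to sm_50 and +ptx60. A minimal sketch of the synthesized attribute, assuming upstream MLIR's textual syntax for gpu.module targets and #nvvm.target (module and kernel names are illustrative):

// A kernel module after the walk, carrying the default NVVM target.
gpu.module @kernels [#nvvm.target<O = 2, triple = "nvptx64-nvidia-cuda",
                                  chip = "sm_50", features = "+ptx60">] {
  gpu.func @kernel() kernel {
    gpu.return
  }
}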

src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp

Lines changed: 114 additions & 14 deletions
@@ -38,6 +38,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -1718,8 +1719,18 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
         ctorBuilder.create<LLVM::AddressOfOp>(loc, fatBinWrapper);
     auto bitcastOfWrapper = ctorBuilder.create<LLVM::BitcastOp>(
         loc, llvmPointerType, addressOfWrapper);
-    auto module = rtRegisterFatBinaryCallBuilder.create(loc, ctorBuilder,
-                                                        {bitcastOfWrapper});
+
+    auto cudaRegisterFatbinFn = LLVM::lookupOrCreateFn(
+        rewriter, moduleOp, "__cudaRegisterFatBinary", llvmPointerType,
+        llvmPointerType);
+    if (failed(cudaRegisterFatbinFn)) {
+      llvm::errs() << " cudamalloc already exists with different types\n";
+      return failure();
+    }
+
+    auto module = rewriter.create<LLVM::CallOp>(
+        loc, cudaRegisterFatbinFn.value(), ValueRange(bitcastOfWrapper));
+
     auto moduleGlobalName =
         std::string(llvm::formatv("polygeist_{0}_module_ptr", moduleName));
     {
@@ -1771,12 +1782,32 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
       auto aoo = ctorBuilder.create<LLVM::AddressOfOp>(loc, stub);
       auto bitcast =
           ctorBuilder.create<LLVM::BitcastOp>(loc, llvmPointerType, aoo);
-      auto ret = rtRegisterFunctionCallBuilder.create(
-          loc, ctorBuilder,
-          {module.getResult(), bitcast, kernelName, kernelName,
-           /* TODO I have no idea what the following params are */
-           ctorBuilder.create<LLVM::ConstantOp>(loc, llvmInt32Type, -1),
-           nullPtr, nullPtr, nullPtr, nullPtr, nullPtr});
+
+      Type tys[] = {llvmPointerType, llvmPointerType, llvmPointerType,
+                    llvmPointerType, llvmInt32Type, llvmPointerType,
+                    llvmPointerType, llvmPointerType, llvmPointerType,
+                    llvmPointerType};
+      auto cudaRegisterFn = LLVM::lookupOrCreateFn(
+          rewriter, moduleOp, "__cudaRegisterFunction", tys,
+          llvmInt32Type);
+      if (failed(cudaRegisterFn)) {
+        llvm::errs()
+            << " cudamalloc already exists with different types\n";
+        return failure();
+      }
+      Value args[] = {
+          module.getResult(),
+          bitcast,
+          kernelName,
+          kernelName,
+          ctorBuilder.create<LLVM::ConstantOp>(loc, llvmInt32Type, -1),
+          nullPtr,
+          nullPtr,
+          nullPtr,
+          nullPtr,
+          nullPtr};
+
+      rewriter.create<LLVM::CallOp>(loc, cudaRegisterFn.value(), args);
     } else if (LLVM::GlobalOp g = dyn_cast<LLVM::GlobalOp>(op)) {
       int addrSpace = g.getAddrSpace();
       if (addrSpace != 1 /* device */ && addrSpace != 4 /* constant */)
@@ -1825,9 +1856,18 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
       }
     }
     // TODO this has to happen only for some CUDA versions
-    if (gpuTarget == "cuda")
-      rtRegisterFatBinaryEndCallBuilder.create(loc, ctorBuilder,
-                                               {module.getResult()});
+    if (gpuTarget == "cuda") {
+      auto cudaRegisterFatbinFn = LLVM::lookupOrCreateFn(
+          rewriter, moduleOp, "__cudaRegisterFatBinaryEnd", llvmPointerType,
+          llvmVoidType);
+      if (failed(cudaRegisterFatbinFn)) {
+        llvm::errs() << " cudamalloc already exists with different types\n";
+        return failure();
+      }
+
+      rewriter.create<LLVM::CallOp>(loc, cudaRegisterFatbinFn.value(),
+                                    ValueRange(module->getResult(0)));
+    }
     ctorBuilder.create<LLVM::ReturnOp>(loc, ValueRange());
   }
   auto ctorSymbol = FlatSymbolRefAttr::get(ctor);
@@ -1847,8 +1887,17 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
     auto aoo = dtorBuilder.create<LLVM::AddressOfOp>(loc, moduleGlobal);
     auto module = dtorBuilder.create<LLVM::LoadOp>(
         loc, llvmPointerPointerType, aoo->getResult(0));
-    rtUnregisterFatBinaryCallBuilder.create(loc, dtorBuilder,
-                                            module.getResult());
+
+    auto cudaUnRegisterFatbinFn = LLVM::lookupOrCreateFn(
+        rewriter, moduleOp, "__cudaUnregisterFatBinary", llvmPointerType,
+        llvmVoidType);
+    if (failed(cudaUnRegisterFatbinFn)) {
+      llvm::errs() << " cudamalloc already exists with different types\n";
+      return failure();
+    }
+
+    rewriter.create<LLVM::CallOp>(loc, cudaUnRegisterFatbinFn.value(),
+                                  ValueRange(module));
     dtorBuilder.create<LLVM::ReturnOp>(loc, ValueRange());
     auto dtorSymbol = FlatSymbolRefAttr::get(dtor);
     {
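Note: the four LLVM::lookupOrCreateFn calls in this file declare CUDA's internal fatbin-registration entry points directly in the module, replacing the previous rt*CallBuilder helpers. A sketch of the declarations these calls materialize, written in MLIR's LLVM dialect assuming opaque pointers; the underlying host symbols are unofficial CUDA runtime internals whose exact signatures vary across CUDA versions:

llvm.func @__cudaRegisterFatBinary(!llvm.ptr) -> !llvm.ptr
llvm.func @__cudaRegisterFunction(!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr,
                                  i32, !llvm.ptr, !llvm.ptr, !llvm.ptr,
                                  !llvm.ptr, !llvm.ptr) -> i32
llvm.func @__cudaRegisterFatBinaryEnd(!llvm.ptr)
llvm.func @__cudaUnregisterFatBinary(!llvm.ptr)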
@@ -2469,6 +2518,34 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
   }
 };
 
+/// Pattern for returning from a function, packs the results into a struct.
+struct GPUReturnOpLowering : public ConvertOpToLLVMPattern<gpu::ReturnOp> {
+public:
+  using ConvertOpToLLVMPattern<gpu::ReturnOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (returnOp->getNumOperands() <= 1) {
+      rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp,
+                                                  adaptor.getOperands());
+      return success();
+    }
+
+    auto returnedType = LLVM::LLVMStructType::getLiteral(
+        returnOp->getContext(),
+        llvm::to_vector(adaptor.getOperands().getTypes()));
+    Value packed =
+        rewriter.create<LLVM::UndefOp>(returnOp->getLoc(), returnedType);
+    for (const auto &[index, value] : llvm::enumerate(adaptor.getOperands())) {
+      packed = rewriter.create<LLVM::InsertValueOp>(returnOp->getLoc(), packed,
+                                                    value, index);
+    }
+    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, packed);
+    return success();
+  }
+};
+
 /// TODO: Temporary until we migrate everything to opaque pointers
 struct ReconcileUnrealizedPointerCasts
     : public OpRewritePattern<UnrealizedConversionCastOp> {
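GPUReturnOpLowering forwards returns with zero or one operand directly to llvm.return and otherwise packs the operands into a literal LLVM struct. A before/after sketch in MLIR (values and types are illustrative):

// Before lowering, inside a gpu.func with two results:
//   gpu.return %a, %b : i32, f32
// After lowering:
%0 = llvm.mlir.undef : !llvm.struct<(i32, f32)>
%1 = llvm.insertvalue %a, %0[0] : !llvm.struct<(i32, f32)>
%2 = llvm.insertvalue %b, %1[1] : !llvm.struct<(i32, f32)>
llvm.return %2 : !llvm.struct<(i32, f32)>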
@@ -2558,6 +2635,23 @@ populateCStyleMemRefLoweringPatterns(RewritePatternSet &patterns,
   patterns.add<CMemcpyOpLowering>(typeConverter);
 }
 
+/// Appends the patterns lowering operations from the Func dialect to the LLVM
+/// dialect using the C-style type conversion, i.e. converting memrefs to
+/// pointer to arrays of arrays.
+static void
+populateCStyleGPUFuncLoweringPatterns(RewritePatternSet &patterns,
+                                      LLVMTypeConverter &typeConverter,
+                                      std::string gpuTarget) {
+  patterns.add<GPUReturnOpLowering>(typeConverter);
+  patterns.add<GPUFuncOpLowering>(
+      typeConverter,
+      /*allocaAddrSpace=*/0,
+      StringAttr::get(&typeConverter.getContext(),
+                      gpuTarget == "cuda"
+                          ? NVVM::NVVMDialect::getKernelFuncAttrName()
+                          : ROCDL::ROCDLDialect::getKernelFuncAttrName()));
+}
+
 /// Appends the patterns lowering operations from the Func dialect to the LLVM
 /// dialect using the C-style type conversion, i.e. converting memrefs to
 /// pointer to arrays of arrays.
@@ -2618,6 +2712,13 @@ struct ConvertPolygeistToLLVMPass
 
     RewritePatternSet patterns(&getContext());
 
+    auto gpuTarget = "cuda";
+
+    // Insert our custom version of GPUFuncLowering
+    if (useCStyleMemRef) {
+      populateCStyleGPUFuncLoweringPatterns(patterns, converter, gpuTarget);
+    }
+
     populatePolygeistToLLVMConversionPatterns(converter, patterns);
     populateSCFToControlFlowConversionPatterns(patterns);
     // populateForBreakToWhilePatterns(patterns);
@@ -2642,7 +2743,6 @@ struct ConvertPolygeistToLLVMPass
 
     // Our custom versions of the gpu patterns
     if (useCStyleMemRef) {
-      auto gpuTarget = "cuda";
       patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
           converter, "gpu.binary", gpuTarget);
       // patterns.add<LegalizeLaunchFuncOpPattern>(
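populateCStyleGPUFuncLoweringPatterns marks lowered kernels with the dialect-appropriate kernel attribute: nvvm.kernel when gpuTarget is "cuda", rocdl.kernel otherwise. An illustrative result of lowering a gpu.func kernel (IR sketch; names and the unit-attribute form are assumptions):

// A kernel after GPUFuncOpLowering with gpuTarget == "cuda".
llvm.func @kernel(%arg0: !llvm.ptr) attributes {nvvm.kernel} {
  llvm.return
}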

src/enzyme_ad/jax/Passes/GPULaunchRecognition.cpp

Lines changed: 6 additions & 8 deletions
@@ -91,7 +91,6 @@ struct GPULaunchRecognitionPass
           cop.getArgOperands()[0].getDefiningOp<LLVM::AddressOfOp>();
       if (!argop)
         continue;
-      llvm::errs() << "argop: " << argop << "\n";
       auto cur = argop.getFunction(symbolTable);
       if (!cur)
         continue;
@@ -156,14 +155,12 @@ struct GPULaunchRecognitionPass
         });
       }
 
-      auto loc = launchFunc->getLoc();
+      auto loc = cop->getLoc();
       builder.setInsertionPointAfter(cop);
 
       auto shMemSize = builder.create<LLVM::TruncOp>(
           loc, builder.getI32Type(), cop.getArgOperands()[7]);
       auto stream = cop.getArgOperands()[8];
-      llvm::errs() << " stream: " << stream << "\n";
-      // TODO stream is arg 8
       llvm::SmallVector<mlir::Value> args;
       for (unsigned i = 9; i < cop.getArgOperands().size(); i++)
         args.push_back(cop.getArgOperands()[i]);
@@ -194,8 +191,8 @@ struct GPULaunchRecognitionPass
                 ValueRange(args));
           } else {
             auto op = builder.create<mlir::gpu::LaunchOp>(
-                loc, grid[0], grid[1], grid[2], block[0], block[1], block[2],
-                shMemSize, nullptr, ValueRange());
+                launchFunc->getLoc(), grid[0], grid[1], grid[2], block[0],
+                block[1], block[2], shMemSize, nullptr, ValueRange());
             builder.setInsertionPointToStart(&op.getRegion().front());
             builder.create<LLVM::CallOp>(loc, cur, args);
             builder.create<gpu::TerminatorOp>(loc);
@@ -208,8 +205,9 @@ struct GPULaunchRecognitionPass
                 ValueRange(args), stream.getType(), ValueRange(stream));
           } else {
             auto op = builder.create<mlir::gpu::LaunchOp>(
-                loc, grid[0], grid[1], grid[2], block[0], block[1], block[2],
-                shMemSize, stream.getType(), ValueRange(stream));
+                launchFunc->getLoc(), grid[0], grid[1], grid[2], block[0],
+                block[1], block[2], shMemSize, stream.getType(),
+                ValueRange(stream));
             builder.setInsertionPointToStart(&op.getRegion().front());
             builder.create<LLVM::CallOp>(loc, cur, args);
             builder.create<gpu::TerminatorOp>(loc);
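The recognition above relies on a fixed operand layout in the matched llvm.call: operand 0 is the llvm.mlir.addressof of the kernel symbol, operand 7 the shared-memory size (truncated to i32), operand 8 the stream, and operands 9 onward the kernel arguments, with the grid and block sizes taken from the operands in between. A hypothetical call shape in MLIR (the callee name and types are illustrative, not from the source):

// Assumed operand layout for a recognizable launch call.
llvm.call @__launch_kernel(%kernel_addr, %gx, %gy, %gz, %bx, %by, %bz,
                           %shmem, %stream, %arg0)
    : (!llvm.ptr, i64, i64, i64, i64, i64, i64, i64, !llvm.ptr, !llvm.ptr) -> ()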

src/enzyme_ad/jax/Passes/ParallelLower.cpp

Lines changed: 35 additions & 1 deletion
@@ -40,6 +40,7 @@ namespace mlir {
 namespace enzyme {
 #define GEN_PASS_DEF_PARALLELLOWER
 #define GEN_PASS_DEF_FIXGPUFUNC
+#define GEN_PASS_DEF_STRIPGPUINFO
 #include "src/enzyme_ad/jax/Passes/Passes.h.inc"
 } // namespace enzyme
 } // namespace mlir
@@ -113,6 +114,10 @@ struct FixGPUFunc : public enzyme::impl::FixGPUFuncBase<FixGPUFunc> {
   using FixGPUFuncBase::FixGPUFuncBase;
   void runOnOperation() override;
 };
+struct StripGPUInfo : public enzyme::impl::StripGPUInfoBase<StripGPUInfo> {
+  using StripGPUInfoBase::StripGPUInfoBase;
+  void runOnOperation() override;
+};
 } // end anonymous namespace
 
 /// Creates a pass to perform optimizations relying on memref dataflow such as
@@ -412,11 +417,16 @@ void ParallelLower::runOnOperation() {
     for (auto op : ops)
       callInliner(op);
   }
+  LLVM::LLVMFuncOp lfn = nullptr;
   {
     SmallVector<LLVM::CallOp> lops;
     launchOp.walk([&](LLVM::CallOp caller) { lops.push_back(caller); });
-    for (auto op : lops)
+    for (auto op : lops) {
+      if (!lfn)
+        lfn = dyn_cast_or_null<LLVM::LLVMFuncOp>(
+            op.resolveCallableInTable(&symbolTable));
       LLVMcallInliner(op);
+    }
   }
 
   mlir::IRRewriter builder(launchOp.getContext());
@@ -449,6 +459,14 @@
         ValueRange({launchOp.getGridSizeX(), launchOp.getGridSizeY(),
                     launchOp.getGridSizeZ(), launchOp.getBlockSizeX(),
                     launchOp.getBlockSizeY(), launchOp.getBlockSizeZ()}));
+    if (lfn) {
+      if (auto passthrough = lfn.getPassthrough()) {
+        pw->setAttr("passthrough", *passthrough);
+      }
+      if (auto passthrough = lfn.getTargetFeatures()) {
+        pw->setAttr("target_features", *passthrough);
+      }
+    }
     builder.setInsertionPointToStart(pw.getBody());
   }
 
@@ -893,6 +911,22 @@ void ConvertCudaRTtoCPU::runOnOperation() {
 }
 #endif
 
+void StripGPUInfo::runOnOperation() {
+  getOperation()->walk([](gpu::GPUModuleOp v) {
+    auto unknown = OpBuilder(v).getUnknownLoc();
+    v->walk([&](Operation *op) {
+      op->setLoc(unknown);
+      for (auto &region : op->getRegions()) {
+        for (auto &blk : region) {
+          for (auto &arg : blk.getArguments()) {
+            arg.setLoc(unknown);
+          }
+        }
+      }
+    });
+  });
+}
+
 // Returns a list of all symbols provided by cudart (obtained from
 // libcudart_static.a)
 static std::vector<llvm::StringRef> getCudartSymbols();
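Two threads come together in this file: ParallelLower now records the first inlined callee (lfn) so the kernel's passthrough and target_features attributes survive onto the wrapper op that ConvertParallelToGPU.cpp (above) later inspects, and the new StripGPUInfo pass resets every location inside a gpu.module, block arguments included, to UnknownLoc. The latter's effect, sketched on illustrative IR:

// Before strip-gpu-info: debug locations on ops and block arguments.
gpu.module @kernels {
  gpu.func @k(%arg0: i64 loc("kern.cu":3:12)) kernel {
    gpu.return loc("kern.cu":5:3)
  } loc("kern.cu":2:1)
}
// After: every location above becomes loc(unknown), which the printer
// elides by default.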

src/enzyme_ad/jax/Passes/Passes.td

Lines changed: 5 additions & 0 deletions
@@ -787,4 +787,9 @@ def FixGPUFunc : Pass<"fix-gpu-func", "mlir::gpu::GPUModuleOp"> {
   let dependentDialects = ["func::FuncDialect", "LLVM::LLVMDialect", "gpu::GPUDialect"];
 }
 
+def StripGPUInfo : Pass<"strip-gpu-info"> {
+  let summary = "Stirng GPU Debug info";
+  let dependentDialects = ["gpu::GPUDialect"];
+}
+
 #endif
