diff --git a/src/enzyme_ad/jax/Dialect/TritonExt/Ops.td b/src/enzyme_ad/jax/Dialect/TritonExt/Ops.td
index 22aef1f07..a13abd53e 100644
--- a/src/enzyme_ad/jax/Dialect/TritonExt/Ops.td
+++ b/src/enzyme_ad/jax/Dialect/TritonExt/Ops.td
@@ -16,6 +16,8 @@ def TensorI64
         "tensor", "::mlir::TensorType">,
     BuildableType<"RankedTensorType::get({}, $_builder.getIntegerType(64))">;
 
+def ScratchTensor : RankedTensorOf<[I8]>;
+
 def TritonModuleOp : TritonExtOp<"module", [
   IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol, NoTerminator
 ]> {
@@ -32,6 +34,21 @@ def TritonModuleOp : TritonExtOp<"module", [
   // clang-format on
 }
 
+def ScratchMemoryOp : TritonExtOp<"scratch_memory", [Pure]> {
+  let summary = "Allocate scratch memory";
+  let description = [{ Allocate scratch memory for a kernel. }];
+
+  let arguments = (ins I64Attr : $alignment);
+
+  let results = (outs ScratchTensor : $result);
+
+  // clang-format off
+  let assemblyFormat = [{
+    attr-dict `:` type($result)
+  }];
+  // clang-format on
+}
+
 def TritonCallOp : TritonExtOp<"call", [
   DeclareOpInterfaceMethods<SymbolUserOpInterface>,
   DeclareOpInterfaceMethods<CallOpInterface>,
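For readers new to the op: a minimal sketch of materializing `triton_ext.scratch_memory` from C++, mirroring the ODS-generated `create` signature as it is used later in this patch. The wrapper function and its parameters are illustrative only, not part of the change:

```c++
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "src/enzyme_ad/jax/Dialect/TritonExt/Dialect.h"

// Illustrative helper: allocate `nbytes` of i8 scratch with the given
// alignment, as TritonAugmentFunctionWithExtraArguments.cpp does below.
static mlir::Value makeScratch(mlir::OpBuilder &b, mlir::Location loc,
                               int64_t nbytes, int64_t align) {
  auto ty = mlir::RankedTensorType::get({nbytes}, b.getIntegerType(8));
  return mlir::enzymexla::triton_ext::ScratchMemoryOp::create(
      b, loc, ty, b.getI64IntegerAttr(align));
}
```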
"src/enzyme_ad/jax/Dialect/Dialect.h" +#include "src/enzyme_ad/jax/Dialect/Ops.h" +#include "src/enzyme_ad/jax/Dialect/TritonExt/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "src/enzyme_ad/jax/Utils.h" + +#include "llvm/ADT/SmallVector.h" + +#define DEBUG_TYPE "lower-triton" + +namespace mlir { +namespace enzyme { +#define GEN_PASS_DEF_LOWERTRITONPASS +#include "src/enzyme_ad/jax/Passes/Passes.h.inc" +} // namespace enzyme +} // namespace mlir + +using namespace mlir; +using namespace mlir::enzyme; +using namespace mlir::enzymexla; +using namespace mlir::enzymexla::triton_ext; + +LogicalResult lowerTritonKernelToKernelCall(ModuleOp mod, + triton_ext::TritonCallOp op) { + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(mod); + auto funcOp = symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr()); + if (!funcOp) { + op->emitError("Failed to find function '") << op.getFn() << "' in module"; + return failure(); + } + + auto wrappedMod = funcOp->getParentOfType(); + if (!wrappedMod) { + op->emitError("Failed to find parent built-in module."); + return failure(); + } + + if (!wrappedMod->hasAttr("ttg.shared")) { + op->emitError("No ttg.shared attribute found. Triton Passes must be run " + "before invoking lower-triton pass."); + return failure(); + } + + auto ttModOP = wrappedMod->getParentOfType(); + if (!ttModOP) { + op->emitError("No `triton_ext.module` found!"); + return failure(); + } + ttModOP.setVisibility(SymbolTable::Visibility::Private); + + OpBuilder builder(op); + + auto sharedMemSizeAttr = wrappedMod->getAttrOfType("ttg.shared"); + auto sharedMemSize = sharedMemSizeAttr.getValue().getZExtValue(); + auto shmemOpType = op.getGridx().getType(); + auto shmemOp = stablehlo::ConstantOp::create( + builder, op.getLoc(), shmemOpType, + cast(makeAttr(shmemOpType, sharedMemSize))); + + auto kernelCallOp = enzymexla::KernelCallOp::create( + builder, op.getLoc(), op.getResultTypes(), op.getFn(), op.getGridx(), + op.getGridy(), op.getGridz(), op.getBlockx(), op.getBlocky(), + op.getBlockz(), shmemOp, op.getClusterx(), op.getClustery(), + op.getClusterz(), op.getInputs(), op.getBackendConfigAttr(), + op.getOperandLayoutsAttr(), op.getResultLayoutsAttr(), + op.getArgAttrsAttr(), op.getResAttrsAttr(), + op.getOutputOperandAliasesAttr(), op.getXlaSideEffectFreeAttr()); + op.replaceAllUsesWith(kernelCallOp); + op.erase(); + return success(); +} + +struct LowerTritonPass + : public mlir::enzyme::impl::LowerTritonPassBase { + using Base::Base; + + void runOnOperation() override { + auto modOp = getOperation(); + + modOp->walk([&](triton_ext::TritonCallOp op) { + if (failed(lowerTritonKernelToKernelCall(modOp, op))) { + signalPassFailure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + } +}; diff --git a/src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp b/src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp new file mode 100644 index 000000000..7882e12d6 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp @@ -0,0 +1,151 @@ +#include "src/enzyme_ad/jax/Passes/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "src/enzyme_ad/jax/Dialect/Dialect.h" +#include "src/enzyme_ad/jax/Dialect/Ops.h" +#include "src/enzyme_ad/jax/Dialect/TritonExt/Dialect.h" +#include 
"triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "src/enzyme_ad/jax/Utils.h" + +#include "llvm/ADT/SmallVector.h" + +#define DEBUG_TYPE "lower-triton-extension-ops" + +namespace mlir { +namespace enzyme { +#define GEN_PASS_DEF_LOWERTRITONEXTENSIONOPSPASS +#include "src/enzyme_ad/jax/Passes/Passes.h.inc" +} // namespace enzyme +} // namespace mlir + +using namespace mlir; +using namespace mlir::enzyme; +using namespace mlir::enzymexla; +using namespace mlir::enzymexla::triton_ext; + +struct JITCallScratchMemoryLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzymexla::JITCallOp op, + PatternRewriter &rewriter) const override { + auto inputs = op.getInputs(); + + BitVector rewriteScratchMemoryIdxs(inputs.size(), false); + SmallVector newInputs; + bool hasScratchMemory = false; + for (size_t i = 0; i < inputs.size(); i++) { + if (auto scratchMemoryOp = + inputs[i].getDefiningOp()) { + hasScratchMemory = true; + rewriteScratchMemoryIdxs.set(i); + continue; + } + newInputs.push_back(inputs[i]); + } + + if (!hasScratchMemory) + return failure(); // nothing to do + + // hoist the scratch memory allocation and use gpu.alloc to allocate this + // memory in the jit call function + auto modOp = op->getParentOfType(); + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(modOp); + auto funcOp = symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr()); + if (!funcOp) { + op->emitError("Failed to find function '") << op.getFn() << "' in module"; + return failure(); + } + + auto funcOpInterface = dyn_cast(funcOp); + + auto &fnBody = funcOp->getRegion(0).front(); + + for (unsigned idx : rewriteScratchMemoryIdxs.set_bits()) { + rewriter.setInsertionPoint(&fnBody, fnBody.begin()); + auto scratchMemoryOp = + inputs[idx].getDefiningOp(); + auto outTy = + cast(scratchMemoryOp.getResult().getType()); + assert(outTy.getRank() == 1); + + auto outMemrefType = MemRefType::get( + outTy.getShape(), outTy.getElementType(), MemRefLayoutAttrInterface{}, + rewriter.getI64IntegerAttr( + cast(fnBody.getArgument(idx).getType()) + .getAddressSpace())); + + auto allocOp = + memref::AllocOp::create(rewriter, op.getLoc(), outMemrefType, + scratchMemoryOp.getAlignmentAttr()); + auto ptrOp = enzymexla::Memref2PointerOp::create( + rewriter, op.getLoc(), + LLVM::LLVMPointerType::get(rewriter.getContext(), + outMemrefType.getMemorySpaceAsInt()), + allocOp.getResult()); + rewriter.replaceAllUsesWith(fnBody.getArgument(idx), ptrOp.getResult()); + + // clang-format off + // FIXME: This is producing + // error: 'llvm.call' op operand type mismatch for operand 0: '!llvm.ptr<1>' != '!llvm.ptr' + // see current operation: "llvm.call"(%61, %60) <{CConv = #llvm.cconv, TailCallKind = #llvm.tailcallkind, callee = @mgpuMemFree, fastmathFlags = #llvm.fastmath, op_bundle_sizes = array, operandSegmentSizes = array}> : (!llvm.ptr<1>, !llvm.ptr) -> () + // SmallVector deps; + // Operation *lastUser = ptrOp; + // for (auto u : ptrOp->getUsers()) { + // if (auto gpuLaunchOp = dyn_cast(u)) { + // deps.push_back(gpuLaunchOp.getAsyncToken()); + // } + + // if (lastUser->isBeforeInBlock(u)) { + // lastUser = u; + // } + // } + + // rewriter.setInsertionPointAfter(lastUser); + // gpu::DeallocOp::create(rewriter, op.getLoc(), + // gpu::AsyncTokenType::get(rewriter.getContext()), + // ValueRange(deps), allocOp.getResult()); + // clang-format on + } + + 
diff --git a/src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp b/src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp
new file mode 100644
index 000000000..7882e12d6
--- /dev/null
+++ b/src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp
@@ -0,0 +1,151 @@
+#include "src/enzyme_ad/jax/Passes/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "src/enzyme_ad/jax/Dialect/Dialect.h"
+#include "src/enzyme_ad/jax/Dialect/Ops.h"
+#include "src/enzyme_ad/jax/Dialect/TritonExt/Dialect.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "src/enzyme_ad/jax/Utils.h"
+
+#include "llvm/ADT/SmallVector.h"
+
+#define DEBUG_TYPE "lower-triton-extension-ops"
+
+namespace mlir {
+namespace enzyme {
+#define GEN_PASS_DEF_LOWERTRITONEXTENSIONOPSPASS
+#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
+} // namespace enzyme
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::enzyme;
+using namespace mlir::enzymexla;
+using namespace mlir::enzymexla::triton_ext;
+
+struct JITCallScratchMemoryLowering
+    : public OpRewritePattern<enzymexla::JITCallOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(enzymexla::JITCallOp op,
+                                PatternRewriter &rewriter) const override {
+    auto inputs = op.getInputs();
+
+    BitVector rewriteScratchMemoryIdxs(inputs.size(), false);
+    SmallVector<Value> newInputs;
+    bool hasScratchMemory = false;
+    for (size_t i = 0; i < inputs.size(); i++) {
+      if (auto scratchMemoryOp =
+              inputs[i].getDefiningOp<triton_ext::ScratchMemoryOp>()) {
+        hasScratchMemory = true;
+        rewriteScratchMemoryIdxs.set(i);
+        continue;
+      }
+      newInputs.push_back(inputs[i]);
+    }
+
+    if (!hasScratchMemory)
+      return failure(); // nothing to do
+
+    // Hoist the scratch allocation into the callee: allocate the memory there
+    // with memref.alloc and replace the corresponding pointer argument.
+    auto modOp = op->getParentOfType<ModuleOp>();
+    SymbolTableCollection symbolTable;
+    symbolTable.getSymbolTable(modOp);
+    auto funcOp = symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr());
+    if (!funcOp) {
+      op->emitError("Failed to find function '") << op.getFn() << "' in module";
+      return failure();
+    }
+
+    auto funcOpInterface = dyn_cast<FunctionOpInterface>(funcOp);
+
+    auto &fnBody = funcOp->getRegion(0).front();
+
+    for (unsigned idx : rewriteScratchMemoryIdxs.set_bits()) {
+      rewriter.setInsertionPoint(&fnBody, fnBody.begin());
+      auto scratchMemoryOp =
+          inputs[idx].getDefiningOp<triton_ext::ScratchMemoryOp>();
+      auto outTy =
+          cast<RankedTensorType>(scratchMemoryOp.getResult().getType());
+      assert(outTy.getRank() == 1);
+
+      auto outMemrefType = MemRefType::get(
+          outTy.getShape(), outTy.getElementType(), MemRefLayoutAttrInterface{},
+          rewriter.getI64IntegerAttr(
+              cast<LLVM::LLVMPointerType>(fnBody.getArgument(idx).getType())
+                  .getAddressSpace()));
+
+      auto allocOp =
+          memref::AllocOp::create(rewriter, op.getLoc(), outMemrefType,
+                                  scratchMemoryOp.getAlignmentAttr());
+      auto ptrOp = enzymexla::Memref2PointerOp::create(
+          rewriter, op.getLoc(),
+          LLVM::LLVMPointerType::get(rewriter.getContext(),
+                                     outMemrefType.getMemorySpaceAsInt()),
+          allocOp.getResult());
+      rewriter.replaceAllUsesWith(fnBody.getArgument(idx), ptrOp.getResult());
+
+      // clang-format off
+      // FIXME: This is producing
+      // error: 'llvm.call' op operand type mismatch for operand 0: '!llvm.ptr<1>' != '!llvm.ptr'
+      // see current operation: "llvm.call"(%61, %60) <{CConv = #llvm.cconv<ccc>, TailCallKind = #llvm.tailcallkind<none>, callee = @mgpuMemFree, fastmathFlags = #llvm.fastmath<none>, op_bundle_sizes = array<i32>, operandSegmentSizes = array<i32: 2, 0>}> : (!llvm.ptr<1>, !llvm.ptr) -> ()
+      // SmallVector<Value> deps;
+      // Operation *lastUser = ptrOp;
+      // for (auto u : ptrOp->getUsers()) {
+      //   if (auto gpuLaunchOp = dyn_cast<gpu::LaunchFuncOp>(u)) {
+      //     deps.push_back(gpuLaunchOp.getAsyncToken());
+      //   }
+
+      //   if (lastUser->isBeforeInBlock(u)) {
+      //     lastUser = u;
+      //   }
+      // }
+
+      // rewriter.setInsertionPointAfter(lastUser);
+      // gpu::DeallocOp::create(rewriter, op.getLoc(),
+      //                        gpu::AsyncTokenType::get(rewriter.getContext()),
+      //                        ValueRange(deps), allocOp.getResult());
+      // clang-format on
+    }
+
+    funcOpInterface.eraseArguments(rewriteScratchMemoryIdxs);
+
+    // TODO: to be safe we should rework the other attributes (e.g. layouts
+    // and aliasing info) if their corresponding operands are being removed.
+    rewriter.setInsertionPoint(op);
+    auto newJitCallOp = enzymexla::JITCallOp::create(
+        rewriter, op.getLoc(), op.getResultTypes(), op.getFn(), newInputs,
+        op.getBackendConfigAttr(), op.getOperandLayoutsAttr(),
+        op.getResultLayoutsAttr(), op.getArgAttrsAttr(), op.getResAttrsAttr(),
+        op.getOutputOperandAliasesAttr(), op.getXlaSideEffectFreeAttr());
+    rewriter.replaceOp(op, newJitCallOp);
+    return success();
+  }
+};
+
+struct LowerTritonExtensionOpsPass
+    : public mlir::enzyme::impl::LowerTritonExtensionOpsPassBase<
+          LowerTritonExtensionOpsPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    auto context = getOperation()->getContext();
+
+    RewritePatternSet patterns(context);
+    patterns.add<JITCallScratchMemoryLowering>(context);
+
+    GreedyRewriteConfig config;
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns), config))) {
+      signalPassFailure();
+    }
+  }
+};
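One subtlety in the rewrite above: the hoisted allocation must live in the same memory space as the pointer argument it replaces, otherwise the later LLVM lowering mismatches address spaces (see the FIXME). A standalone sketch of that type computation, with a hypothetical helper name:

```c++
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

// Illustrative: memref type for the hoisted scratch buffer, inheriting the
// address space of the LLVM pointer argument being replaced.
static mlir::MemRefType scratchMemRefType(mlir::RankedTensorType tensorTy,
                                          mlir::LLVM::LLVMPointerType ptrTy,
                                          mlir::Builder &b) {
  return mlir::MemRefType::get(
      tensorTy.getShape(), tensorTy.getElementType(),
      mlir::MemRefLayoutAttrInterface{},
      b.getI64IntegerAttr(ptrTy.getAddressSpace()));
}
```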
+#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "src/enzyme_ad/jax/Utils.h" + +#include "llvm/ADT/SmallVector.h" + +#define DEBUG_TYPE "triton-augment-function-with-extra-arguments" + +namespace mlir { +namespace enzyme { +#define GEN_PASS_DEF_TRITONAUGMENTFUNCTIONWITHEXTRAARGUMENTSPASS +#include "src/enzyme_ad/jax/Passes/Passes.h.inc" +} // namespace enzyme +} // namespace mlir + +using namespace mlir; +using namespace mlir::enzyme; +using namespace mlir::enzymexla; +using namespace mlir::enzymexla::triton_ext; + +// See for description on the extra arguments +// https://github.com/triton-lang/triton/blob/6ac622c57152ce88edd058f11997b5c5e18d096b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp#L12-L25 + +LogicalResult +augmentTritonCallOpWithExtraArguments(ModuleOp mod, + triton_ext::TritonCallOp op) { + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(mod); + auto funcOp = symbolTable.lookupNearestSymbolFrom( + mod, op.getFnAttr()); + if (!funcOp) { + op->emitError("Failed to find function '") << op.getFn() << "' in module"; + return failure(); + } + + auto fnKind = funcOp->getName().getStringRef(); + if (fnKind != "llvm.func") { + op->emitError("augmentTritonCallOpWithExtraArguments: expected '") + << op.getFn() << "' to be a llvm.func, got: " << fnKind << ". This " + << "means that the pass is being called before tt.func is being " + "lowered to llvm.func"; + return failure(); + } + + if (funcOp.getNumArguments() == op.getInputs().size()) { + return success(); // already augmented + } + + // See NOTE: [Additional Function Arguments] in triton-lang/triton + if (!mlir::triton::isKernel(funcOp)) { + op->emitError("not a kernel function"); + return failure(); + } + + bool hasProfileScratchMemory = + funcOp.getNumArguments() == + op.getInputs().size() + 2; // to support compatibility with old kernels + + if (funcOp.getNumArguments() != + op.getInputs().size() + 1 + hasProfileScratchMemory) { + op->emitError("Expected ") + << (funcOp.getNumArguments() - 1 - hasProfileScratchMemory) + << " arguments, got " << op.getInputs().size(); + return failure(); + } + + auto newInputs = llvm::to_vector(op.getInputs()); + + // global scratch memory + uint64_t gsmNBytes = 0, gsmAlign = 0; + if (auto gsm = funcOp->getAttrOfType( + "ttg.global_scratch_memory_size")) { + gsmNBytes = gsm.getValue().getZExtValue(); + } + if (auto smalign = funcOp->getAttrOfType( + "ttg.global_scratch_memory_alignment")) { + gsmAlign = smalign.getValue().getZExtValue(); + } + + OpBuilder builder(op); + + auto gsmTy = RankedTensorType::get({static_cast(gsmNBytes)}, + builder.getIntegerType(8)); + auto gsm = triton_ext::ScratchMemoryOp::create( + builder, op.getLoc(), gsmTy, builder.getI64IntegerAttr(gsmAlign)); + newInputs.push_back(gsm); + + // profile scratch memory + if (hasProfileScratchMemory) { + uint64_t psmNBytes = 0, psmAlign = 1; + if (auto psm = funcOp->getAttrOfType( + "ttg.profile_scratch_memory_size")) { + psmNBytes = psm.getValue().getZExtValue(); + } + if (auto psmalign = funcOp->getAttrOfType( + "ttg.profile_scratch_memory_alignment")) { + psmAlign = psmalign.getValue().getZExtValue(); + } + + auto psmTy = RankedTensorType::get({static_cast(psmNBytes)}, + builder.getIntegerType(8)); + auto psm = triton_ext::ScratchMemoryOp::create( + builder, op.getLoc(), psmTy, builder.getI64IntegerAttr(psmAlign)); + newInputs.push_back(psm); + } + + auto newCallOp = triton_ext::TritonCallOp::create( + builder, op.getLoc(), op.getResultTypes(), op.getFn(), 
diff --git a/src/enzyme_ad/jax/Passes/TritonAugmentFunctionWithExtraArguments.cpp b/src/enzyme_ad/jax/Passes/TritonAugmentFunctionWithExtraArguments.cpp
new file mode 100644
index 000000000..0fb2691d5
--- /dev/null
+++ b/src/enzyme_ad/jax/Passes/TritonAugmentFunctionWithExtraArguments.cpp
@@ -0,0 +1,145 @@
+#include "src/enzyme_ad/jax/Passes/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "src/enzyme_ad/jax/Dialect/TritonExt/Dialect.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/Triton/IR/Utility.h"
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "src/enzyme_ad/jax/Utils.h"
+
+#include "llvm/ADT/SmallVector.h"
+
+#define DEBUG_TYPE "triton-augment-function-with-extra-arguments"
+
+namespace mlir {
+namespace enzyme {
+#define GEN_PASS_DEF_TRITONAUGMENTFUNCTIONWITHEXTRAARGUMENTSPASS
+#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
+} // namespace enzyme
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::enzyme;
+using namespace mlir::enzymexla;
+using namespace mlir::enzymexla::triton_ext;
+
+// See the following for a description of the extra arguments:
+// https://github.com/triton-lang/triton/blob/6ac622c57152ce88edd058f11997b5c5e18d096b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp#L12-L25
+
+LogicalResult
+augmentTritonCallOpWithExtraArguments(ModuleOp mod,
+                                      triton_ext::TritonCallOp op) {
+  SymbolTableCollection symbolTable;
+  symbolTable.getSymbolTable(mod);
+  auto funcOp = symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
+      mod, op.getFnAttr());
+  if (!funcOp) {
+    op->emitError("Failed to find function '") << op.getFn() << "' in module";
+    return failure();
+  }
+
+  auto fnKind = funcOp->getName().getStringRef();
+  if (fnKind != "llvm.func") {
+    op->emitError("augmentTritonCallOpWithExtraArguments: expected '")
+        << op.getFn() << "' to be an llvm.func, got: " << fnKind << ". This "
+        << "means the pass is being run before tt.func has been lowered to "
+           "llvm.func";
+    return failure();
+  }
+
+  if (funcOp.getNumArguments() == op.getInputs().size()) {
+    return success(); // already augmented
+  }
+
+  // See NOTE: [Additional Function Arguments] in triton-lang/triton
+  if (!mlir::triton::isKernel(funcOp)) {
+    op->emitError("not a kernel function");
+    return failure();
+  }
+
+  // Kernels compiled by older Triton versions take only the global scratch
+  // pointer; newer ones also take a profile scratch pointer.
+  bool hasProfileScratchMemory =
+      funcOp.getNumArguments() == op.getInputs().size() + 2;
+
+  if (funcOp.getNumArguments() !=
+      op.getInputs().size() + 1 + hasProfileScratchMemory) {
+    op->emitError("Expected ")
+        << (funcOp.getNumArguments() - 1 - hasProfileScratchMemory)
+        << " arguments, got " << op.getInputs().size();
+    return failure();
+  }
+
+  auto newInputs = llvm::to_vector(op.getInputs());
+
+  // global scratch memory
+  uint64_t gsmNBytes = 0, gsmAlign = 0;
+  if (auto gsm = funcOp->getAttrOfType<IntegerAttr>(
+          "ttg.global_scratch_memory_size")) {
+    gsmNBytes = gsm.getValue().getZExtValue();
+  }
+  if (auto smalign = funcOp->getAttrOfType<IntegerAttr>(
+          "ttg.global_scratch_memory_alignment")) {
+    gsmAlign = smalign.getValue().getZExtValue();
+  }
+
+  OpBuilder builder(op);
+
+  auto gsmTy = RankedTensorType::get({static_cast<int64_t>(gsmNBytes)},
+                                     builder.getIntegerType(8));
+  auto gsm = triton_ext::ScratchMemoryOp::create(
+      builder, op.getLoc(), gsmTy, builder.getI64IntegerAttr(gsmAlign));
+  newInputs.push_back(gsm);
+
+  // profile scratch memory
+  if (hasProfileScratchMemory) {
+    uint64_t psmNBytes = 0, psmAlign = 1;
+    if (auto psm = funcOp->getAttrOfType<IntegerAttr>(
+            "ttg.profile_scratch_memory_size")) {
+      psmNBytes = psm.getValue().getZExtValue();
+    }
+    if (auto psmalign = funcOp->getAttrOfType<IntegerAttr>(
+            "ttg.profile_scratch_memory_alignment")) {
+      psmAlign = psmalign.getValue().getZExtValue();
+    }
+
+    auto psmTy = RankedTensorType::get({static_cast<int64_t>(psmNBytes)},
+                                       builder.getIntegerType(8));
+    auto psm = triton_ext::ScratchMemoryOp::create(
+        builder, op.getLoc(), psmTy, builder.getI64IntegerAttr(psmAlign));
+    newInputs.push_back(psm);
+  }
+
+  auto newCallOp = triton_ext::TritonCallOp::create(
+      builder, op.getLoc(), op.getResultTypes(), op.getFn(), op.getGridx(),
+      op.getGridy(), op.getGridz(), op.getBlockx(), op.getBlocky(),
+      op.getBlockz(), op.getClusterx(), op.getClustery(), op.getClusterz(),
+      newInputs, op.getBackendConfigAttr(), op.getOperandLayoutsAttr(),
+      /*resultLayouts*/ nullptr, op.getArgAttrsAttr(), op.getResAttrsAttr(),
+      op.getOutputOperandAliasesAttr(), op.getXlaSideEffectFreeAttr());
+  op.replaceAllUsesWith(newCallOp);
+  op.erase();
+  return success();
+}
+
+struct TritonAugmentFunctionWithExtraArgumentsPass
+    : public mlir::enzyme::impl::
+          TritonAugmentFunctionWithExtraArgumentsPassBase<
+              TritonAugmentFunctionWithExtraArgumentsPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    auto modOp = getOperation();
+
+    modOp->walk([&](triton_ext::TritonCallOp op) -> WalkResult {
+      if (failed(augmentTritonCallOpWithExtraArguments(modOp, op))) {
+        signalPassFailure();
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
+  }
+};
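For reference, a hedged restatement of the arity logic handled above (per the Triton note linked in the source); the helper name is hypothetical:

```c++
// A Triton kernel lowered to llvm.func takes the user arguments, then a
// global-scratch pointer, and on newer Triton versions an additional
// profile-scratch pointer:
//
//   llvm.func @kernel(<user args...>, %global_scratch [, %profile_scratch])
//
// Hence the pass accepts callee arity == call operands + 1 or + 2.
static unsigned expectedKernelArity(unsigned numCallOperands,
                                    bool hasProfileScratch) {
  return numCallOperands + 1u + (hasProfileScratch ? 1u : 0u);
}
```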