Skip to content
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/enzyme_ad/jax/Dialect/TritonExt/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def TensorI64
"tensor<i64>", "::mlir::TensorType">,
BuildableType<"RankedTensorType::get({}, $_builder.getIntegerType(64))">;

// Ranked tensor of i8 elements used to model a raw scratch-memory buffer.
def ScratchTensor : RankedTensorOf<[I8]>;

def TritonModuleOp : TritonExtOp<"module", [
IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol, NoTerminator
]> {
Expand All @@ -32,6 +34,21 @@ def TritonModuleOp : TritonExtOp<"module", [
// clang-format on
}

// Allocates a raw byte buffer (modeled as a ranked i8 tensor) for a kernel.
// Marked Pure so unused allocations can be dead-code eliminated. When this op
// feeds an `enzymexla.jit_call`, the lowering replaces the call operand with a
// device allocation inside the callee (see JITCallScratchMemoryLowering).
// `alignment` is forwarded to the underlying memref.alloc — presumably a byte
// alignment; TODO confirm at the allocation lowering.
def ScratchMemoryOp : TritonExtOp<"scratch_memory", [Pure]> {
  let summary = "Allocate scratch memory";
  let description = [{ Allocate scratch memory for a kernel. }];

  let arguments = (ins I64Attr : $alignment);

  let results = (outs ScratchTensor : $result);

  // clang-format off
  let assemblyFormat = [{
    attr-dict `:` type($result)
  }];
  // clang-format on
}

def TritonCallOp : TritonExtOp<"call", [
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<CallOpInterface>,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#include "src/enzyme_ad/jax/Passes/Passes.h"

#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"

#include "triton/Conversion/TritonToTritonGPU/Passes.h"

#define DEBUG_TYPE "convert-triton-to-triton-gpu-preserving-module-attributes"

namespace mlir {
namespace enzyme {
#define GEN_PASS_DEF_CONVERTTRITONTOTRITONGPUPRESERVINGMODULEATTRIBUTESPASS
#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
} // namespace enzyme
} // namespace mlir

using namespace mlir;
using namespace mlir::enzyme;

/// Runs the upstream TritonToTritonGPU conversion on a module, taking the
/// launch configuration (num-ctas, num-warps, threads-per-warp, source remat)
/// from `enzymexla.ttg.*` module attributes instead of pass options, so that
/// modules with different configurations can be compiled in one pipeline.
struct ConvertTritonToTritonGPUPreservingModuleAttributesPass
    : public mlir::enzyme::impl::
          ConvertTritonToTritonGPUPreservingModuleAttributesPassBase<
              ConvertTritonToTritonGPUPreservingModuleAttributesPass> {
  using Base::Base;

  void runOnOperation() override {
    ModuleOp mod = getOperation();

    // Defaults match Triton's usual single-kernel assumptions.
    int32_t numWarps = 4, threadsPerWarp = 32, numCtas = 1;

    // Use getAttrOfType + null check rather than hasAttr + unconditional
    // getInt(): if the attribute exists but is not an IntegerAttr, the
    // original hasAttr-guarded code would call getInt() on a null attribute.
    if (auto attr = mod->getAttrOfType<IntegerAttr>("enzymexla.ttg.num-ctas"))
      numCtas = attr.getInt();

    if (auto attr = mod->getAttrOfType<IntegerAttr>("enzymexla.ttg.num-warps"))
      numWarps = attr.getInt();

    if (auto attr =
            mod->getAttrOfType<IntegerAttr>("enzymexla.ttg.threads-per-warp"))
      threadsPerWarp = attr.getInt();

    // Presence of the attribute alone enables source rematerialization.
    bool enableSourceRemat =
        mod->hasAttr("enzymexla.ttg.enable-source-remat");

    // Run the upstream conversion as a dynamic nested pipeline so failures
    // are reported against this pass.
    OpPassManager pm;
    pm.addPass(triton::createConvertTritonToTritonGPU(
        {target, numWarps, threadsPerWarp, numCtas, enableSourceRemat}));
    if (failed(runPipeline(pm, mod))) {
      mod->emitError() << "failed to run triton passes";
      signalPassFailure();
    }
  }
};
98 changes: 98 additions & 0 deletions src/enzyme_ad/jax/Passes/LowerTriton.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#include "src/enzyme_ad/jax/Passes/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "src/enzyme_ad/jax/Dialect/Dialect.h"
#include "src/enzyme_ad/jax/Dialect/Ops.h"
#include "src/enzyme_ad/jax/Dialect/TritonExt/Dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "src/enzyme_ad/jax/Utils.h"

#include "llvm/ADT/SmallVector.h"

#define DEBUG_TYPE "lower-triton"

namespace mlir {
namespace enzyme {
#define GEN_PASS_DEF_LOWERTRITONPASS
#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
} // namespace enzyme
} // namespace mlir

using namespace mlir;
using namespace mlir::enzyme;
using namespace mlir::enzymexla;
using namespace mlir::enzymexla::triton_ext;

// Rewrites one `triton_ext.call` into an `enzymexla.kernel_call`, forwarding
// the grid/block/cluster launch operands and all call attributes unchanged,
// and materializing the shared-memory size that the TritonGPU pipeline
// recorded on the wrapped module (`ttg.shared`) as an extra constant operand.
// Emits an error on `op` and returns failure if the callee symbol, its parent
// module, or the required attribute cannot be found.
LogicalResult lowerTritonKernelToKernelCall(ModuleOp mod,
                                            triton_ext::TritonCallOp op) {
  // Resolve the callee symbol relative to the call site.
  SymbolTableCollection symbolTable;
  symbolTable.getSymbolTable(mod);
  auto funcOp = symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr());
  if (!funcOp) {
    op->emitError("Failed to find function '") << op.getFn() << "' in module";
    return failure();
  }

  auto wrappedMod = funcOp->getParentOfType<ModuleOp>();
  if (!wrappedMod) {
    op->emitError("Failed to find parent built-in module.");
    return failure();
  }

  // `ttg.shared` is set by the Triton GPU passes; its absence means they have
  // not been run yet, which this lowering depends on.
  if (!wrappedMod->hasAttr("ttg.shared")) {
    op->emitError("No ttg.shared attribute found. Triton Passes must be run "
                  "before invoking lower-triton pass.");
    return failure();
  }

  auto ttModOP = wrappedMod->getParentOfType<triton_ext::TritonModuleOp>();
  if (!ttModOP) {
    op->emitError("No `triton_ext.module` found!");
    return failure();
  }
  // Once lowered, the wrapper module no longer needs to be externally visible.
  ttModOP.setVisibility(SymbolTable::Visibility::Private);

  OpBuilder builder(op);

  // Materialize the shared-memory byte count as a constant of the same tensor
  // type as the grid operands, since kernel_call takes it as an SSA value.
  auto sharedMemSizeAttr = wrappedMod->getAttrOfType<IntegerAttr>("ttg.shared");
  auto sharedMemSize = sharedMemSizeAttr.getValue().getZExtValue();
  auto shmemOpType = op.getGridx().getType();
  auto shmemOp = stablehlo::ConstantOp::create(
      builder, op.getLoc(), shmemOpType,
      cast<ElementsAttr>(makeAttr(shmemOpType, sharedMemSize)));

  // Build the replacement call with identical operands/attributes plus the
  // shared-memory constant, then swap it in for the original op.
  auto kernelCallOp = enzymexla::KernelCallOp::create(
      builder, op.getLoc(), op.getResultTypes(), op.getFn(), op.getGridx(),
      op.getGridy(), op.getGridz(), op.getBlockx(), op.getBlocky(),
      op.getBlockz(), shmemOp, op.getClusterx(), op.getClustery(),
      op.getClusterz(), op.getInputs(), op.getBackendConfigAttr(),
      op.getOperandLayoutsAttr(), op.getResultLayoutsAttr(),
      op.getArgAttrsAttr(), op.getResAttrsAttr(),
      op.getOutputOperandAliasesAttr(), op.getXlaSideEffectFreeAttr());
  op.replaceAllUsesWith(kernelCallOp);
  op.erase();
  return success();
}

/// Pass driver: rewrites every `triton_ext.call` in the module into an
/// `enzymexla.kernel_call`, stopping at the first failure.
struct LowerTritonPass
    : public mlir::enzyme::impl::LowerTritonPassBase<LowerTritonPass> {
  using Base::Base;

  void runOnOperation() override {
    ModuleOp mod = getOperation();

    // Post-order walk, so erasing the visited call op inside the helper is
    // safe. Abort the traversal as soon as one lowering fails.
    mod->walk([&](triton_ext::TritonCallOp callOp) {
      if (succeeded(lowerTritonKernelToKernelCall(mod, callOp)))
        return WalkResult::advance();
      signalPassFailure();
      return WalkResult::interrupt();
    });
  }
};
146 changes: 146 additions & 0 deletions src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
#include "src/enzyme_ad/jax/Passes/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "src/enzyme_ad/jax/Dialect/Dialect.h"
#include "src/enzyme_ad/jax/Dialect/Ops.h"
#include "src/enzyme_ad/jax/Dialect/TritonExt/Dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "src/enzyme_ad/jax/Utils.h"

#include "llvm/ADT/SmallVector.h"

#define DEBUG_TYPE "lower-triton-extension-ops"

namespace mlir {
namespace enzyme {
#define GEN_PASS_DEF_LOWERTRITONEXTENSIONOPSPASS
#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
} // namespace enzyme
} // namespace mlir

using namespace mlir;
using namespace mlir::enzyme;
using namespace mlir::enzymexla;
using namespace mlir::enzymexla::triton_ext;

// Hoists `triton_ext.scratch_memory` operands of an `enzymexla.jit_call` into
// the callee: each scratch operand is dropped from the call, the matching
// callee pointer argument is replaced by a fresh allocation
// (memref.alloc + enzymexla.memref2pointer) at the top of the function body,
// and a dealloc is inserted after the last user of that pointer. The jit_call
// is then rebuilt without the scratch operands.
struct JITCallScratchMemoryLowering
    : public OpRewritePattern<enzymexla::JITCallOp> {
  using OpRewritePattern<enzymexla::JITCallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(enzymexla::JITCallOp op,
                                PatternRewriter &rewriter) const override {
    auto inputs = op.getInputs();

    // Partition the operands: record which ones are scratch allocations and
    // keep everything else as the new operand list.
    BitVector rewriteScratchMemoryIdxs(inputs.size(), false);
    SmallVector<Value> newInputs;
    bool hasScratchMemory = false;
    for (size_t i = 0; i < inputs.size(); i++) {
      if (auto scratchMemoryOp =
              inputs[i].getDefiningOp<triton_ext::ScratchMemoryOp>()) {
        hasScratchMemory = true;
        rewriteScratchMemoryIdxs.set(i);
        continue;
      }
      newInputs.push_back(inputs[i]);
    }

    if (!hasScratchMemory)
      return failure(); // nothing to do

    // hoist the scratch memory allocation and use gpu.alloc to allocate this
    // memory in the jit call function
    auto modOp = op->getParentOfType<ModuleOp>();
    SymbolTableCollection symbolTable;
    symbolTable.getSymbolTable(modOp);
    auto funcOp = symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr());
    if (!funcOp) {
      op->emitError("Failed to find function '") << op.getFn() << "' in module";
      return failure();
    }

    // Guard the cast: the original unchecked dyn_cast would be dereferenced
    // as null in eraseArguments below if the symbol is not function-like.
    auto funcOpInterface = dyn_cast<FunctionOpInterface>(funcOp);
    if (!funcOpInterface) {
      op->emitError("Symbol '") << op.getFn() << "' is not a function-like op";
      return failure();
    }

    auto &fnBody = funcOp->getRegion(0).front();

    for (unsigned idx : rewriteScratchMemoryIdxs.set_bits()) {
      // Allocations go at the very top of the callee body.
      rewriter.setInsertionPoint(&fnBody, fnBody.begin());
      auto scratchMemoryOp =
          inputs[idx].getDefiningOp<triton_ext::ScratchMemoryOp>();
      auto outTy =
          cast<RankedTensorType>(scratchMemoryOp.getResult().getType());
      assert(outTy.getRank() == 1);

      // Allocate in the same address space as the pointer argument we are
      // replacing.
      auto outMemrefType = MemRefType::get(
          outTy.getShape(), outTy.getElementType(), MemRefLayoutAttrInterface{},
          rewriter.getI64IntegerAttr(
              cast<LLVM::LLVMPointerType>(fnBody.getArgument(idx).getType())
                  .getAddressSpace()));

      auto allocOp =
          memref::AllocOp::create(rewriter, op.getLoc(), outMemrefType,
                                  scratchMemoryOp.getAlignmentAttr());
      auto ptrOp = enzymexla::Memref2PointerOp::create(
          rewriter, op.getLoc(),
          LLVM::LLVMPointerType::get(rewriter.getContext(),
                                     outMemrefType.getMemorySpaceAsInt()),
          allocOp.getResult());
      rewriter.replaceAllUsesWith(fnBody.getArgument(idx), ptrOp.getResult());

      // Collect async tokens of gpu.launch_func users so the dealloc waits on
      // them, and find the last user so the dealloc is inserted after it.
      SmallVector<Value> deps;
      Operation *lastUser = ptrOp;
      for (auto u : ptrOp->getUsers()) {
        if (auto gpuLaunchOp = dyn_cast<gpu::LaunchFuncOp>(u)) {
          deps.push_back(gpuLaunchOp.getAsyncToken());
        }

        if (lastUser->isBeforeInBlock(u)) {
          lastUser = u;
        }
      }

      // NOTE(review): the buffer is created with memref.alloc but released
      // with gpu.dealloc — confirm the target lowering accepts this pairing.
      rewriter.setInsertionPointAfter(lastUser);
      gpu::DeallocOp::create(rewriter, op.getLoc(),
                             gpu::AsyncTokenType::get(rewriter.getContext()),
                             ValueRange(deps), allocOp.getResult());
    }

    funcOpInterface.eraseArguments(rewriteScratchMemoryIdxs);

    // TODO: to be safe we should rework the other attributes if they are being
    // removed...
    rewriter.setInsertionPoint(op);
    auto newJitCallOp = enzymexla::JITCallOp::create(
        rewriter, op.getLoc(), op.getResultTypes(), op.getFn(), newInputs,
        op.getBackendConfigAttr(), op.getOperandLayoutsAttr(),
        op.getResultLayoutsAttr(), op.getArgAttrsAttr(), op.getResAttrsAttr(),
        op.getOutputOperandAliasesAttr(), op.getXlaSideEffectFreeAttr());
    rewriter.replaceOp(op, newJitCallOp);
    return success();
  }
};

/// Pass driver: greedily applies the triton_ext lowering patterns (currently
/// only JITCallScratchMemoryLowering) to the anchor op.
struct LowerTritonExtensionOpsPass
    : public mlir::enzyme::impl::LowerTritonExtensionOpsPassBase<
          LowerTritonExtensionOpsPass> {
  using Base::Base;

  void runOnOperation() override {
    Operation *root = getOperation();
    MLIRContext *ctx = root->getContext();

    RewritePatternSet patterns(ctx);
    patterns.add<JITCallScratchMemoryLowering>(ctx);

    if (failed(applyPatternsAndFoldGreedily(root, std::move(patterns),
                                            GreedyRewriteConfig())))
      signalPassFailure();
  }
};
48 changes: 48 additions & 0 deletions src/enzyme_ad/jax/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1008,4 +1008,52 @@ def SCFCPUify : Pass<"cpuify"> {
Option<"method", "method", "std::string", /*default=*/"\"distribute\"", "Method of doing distribution">
];
}

def ConvertTritonToTritonGPUPreservingModuleAttributesPass : Pass<
  "convert-triton-to-triton-gpu-preserving-module-attributes", "mlir::ModuleOp"> {
  // Keep the summary short; the rationale lives in the description.
  let summary = "Lower Triton to TritonGPU using per-module launch attributes";
  let description = [{
    Triton generally compiles a single kernel at a time, so the upstream
    conversion takes the number of CTAs and warps as pass options. To compile
    multiple kernels with differing configurations, this pass instead reads
    the `enzymexla.ttg.*` attributes from the module and uses them to drive
    the lowering to TritonGPU.
  }];
  let dependentDialects = [];
  let options = [
    Option<
      /*C++ variable name=*/"target",
      /*CLI argument=*/"target",
      /*type=*/"std::string",
      /*default=*/"\"\"",
      /*description=*/"the GPU target, e.g., cuda:80, hip:gfx942"
    >];
}

def LowerTritonPass : Pass<"lower-triton", "mlir::ModuleOp"> {
  let summary = "Lower Triton to kernel call";
  let description = [{
    Rewrites each `triton_ext.call` into an `enzymexla.kernel_call`,
    forwarding the launch dimensions and attaching the shared-memory size
    recorded on the wrapped module as `ttg.shared`. The Triton GPU passes
    must run before this pass so that attribute exists.
  }];
  let dependentDialects = [
    "triton::TritonDialect",
    "gpu::GPUDialect",
    "enzymexla::EnzymeXLADialect",
    "func::FuncDialect",
    "enzymexla::triton_ext::TritonExtDialect",
    "stablehlo::StablehloDialect",
  ];
}

def TritonAugmentFunctionWithExtraArgumentsPass : Pass<
  "triton-augment-function-with-extra-arguments", "mlir::ModuleOp"> {
  // NOTE(review): the implementation is not in this file; this summary only
  // restates the pass mnemonic — refine once the behavior is confirmed.
  let summary = "Augment Triton functions with extra arguments";
  let dependentDialects = [
    "triton::TritonDialect",
    "func::FuncDialect",
    "enzymexla::triton_ext::TritonExtDialect",
  ];
}

def LowerTritonExtensionOpsPass : Pass<"lower-triton-extension-ops"> {
  let summary = "Lower triton_ext helper ops used by enzymexla.jit_call";
  let description = [{
    Rewrites `triton_ext` helper ops that survive earlier lowerings. In
    particular, `triton_ext.scratch_memory` operands of `enzymexla.jit_call`
    are replaced by allocations materialized inside the called function.
  }];
  let dependentDialects = [
    "triton::TritonDialect",
    "func::FuncDialect",
    "LLVM::LLVMDialect",
    "memref::MemRefDialect",
    "enzymexla::EnzymeXLADialect",
    "enzymexla::triton_ext::TritonExtDialect",
    "gpu::GPUDialect",
  ];
}

#endif
Loading
Loading