Skip to content

Commit 5832eca

Browse files
authored
feat: add a pass to convert tt to ttg preserving attributes (#1604)
1 parent 980931d commit 5832eca

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#include "src/enzyme_ad/jax/Passes/Passes.h"
2+
3+
#include "mlir/IR/PatternMatch.h"
4+
#include "mlir/Pass/Pass.h"
5+
#include "mlir/Pass/PassManager.h"
6+
7+
#include "triton/Conversion/TritonToTritonGPU/Passes.h"
8+
9+
#define DEBUG_TYPE "convert-triton-to-triton-gpu-preserving-module-attributes"
10+
11+
namespace mlir {
12+
namespace enzyme {
13+
#define GEN_PASS_DEF_CONVERTTRITONTOTRITONGPUPRESERVINGMODULEATTRIBUTESPASS
14+
#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
15+
} // namespace enzyme
16+
} // namespace mlir
17+
18+
using namespace mlir;
19+
using namespace mlir::enzyme;
20+
21+
struct ConvertTritonToTritonGPUPreservingModuleAttributesPass
22+
: public mlir::enzyme::impl::
23+
ConvertTritonToTritonGPUPreservingModuleAttributesPassBase<
24+
ConvertTritonToTritonGPUPreservingModuleAttributesPass> {
25+
using Base::Base;
26+
27+
void runOnOperation() override {
28+
ModuleOp mod = getOperation();
29+
30+
int32_t numWarps = 4, threadsPerWarp = 32, numCtas = 1;
31+
bool enableSourceRemat = false;
32+
33+
if (mod->hasAttr("enzymexla.ttg.num-ctas")) {
34+
numCtas =
35+
mod->getAttrOfType<IntegerAttr>("enzymexla.ttg.num-ctas").getInt();
36+
}
37+
38+
if (mod->hasAttr("enzymexla.ttg.num-warps")) {
39+
numWarps =
40+
mod->getAttrOfType<IntegerAttr>("enzymexla.ttg.num-warps").getInt();
41+
}
42+
43+
if (mod->hasAttr("enzymexla.ttg.threads-per-warp")) {
44+
threadsPerWarp =
45+
mod->getAttrOfType<IntegerAttr>("enzymexla.ttg.threads-per-warp")
46+
.getInt();
47+
}
48+
49+
if (mod->hasAttr("enzymexla.ttg.enable-source-remat")) {
50+
enableSourceRemat = true;
51+
}
52+
53+
OpPassManager pm;
54+
pm.addPass(triton::createConvertTritonToTritonGPU(
55+
{target, numWarps, threadsPerWarp, numCtas, enableSourceRemat}));
56+
if (failed(runPipeline(pm, mod))) {
57+
mod->emitError() << "failed to run triton passes";
58+
signalPassFailure();
59+
return;
60+
}
61+
62+
return;
63+
}
64+
};

src/enzyme_ad/jax/Passes/Passes.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,4 +1023,19 @@ def SCFCPUify : Pass<"cpuify"> {
10231023
Option<"method", "method", "std::string", /*default=*/"\"distribute\"", "Method of doing distribution">
10241024
];
10251025
}
1026+
1027+
// Declares the `convert-triton-to-triton-gpu-preserving-module-attributes`
// pass, anchored on `mlir::ModuleOp`, with a single `target` string option.
def ConvertTritonToTritonGPUPreservingModuleAttributesPass : Pass<
    "convert-triton-to-triton-gpu-preserving-module-attributes", "mlir::ModuleOp"> {
  // Kept to one line: the summary is what pass listings display.
  let summary = "Lower Triton to TritonGPU using parameters taken from module attributes";
  let description = [{
    Triton generally compiles a single kernel, so it can specify the number of
    CTAs and warps. However, we want to be able to compile multiple kernels.
    This pass reads those values from attributes on the module and uses them
    to lower to TritonGPU.
  }];
  // NOTE(review): no dependent dialects are declared here. The pass runs a
  // nested Triton-to-TritonGPU pipeline; confirm whether the dialects that
  // pipeline produces need to be listed.
  let dependentDialects = [];
  let options = [
    Option<
      /*C++ variable name=*/"target",
      /*CLI argument=*/"target",
      /*type=*/"std::string",
      /*default=*/"\"\"",
      /*description=*/"the GPU target, e.g., cuda:80, hip:gfx942"
    >];
}
1040+
10261041
#endif

0 commit comments

Comments
 (0)