Skip to content

Commit 905b3d1

Browse files
CRobeckplotfi
andauthored
[BACKEND] Add hook for configurable/overridable compiler pass pipeline (#8137)
Triton’s existing pass pipelines are explicitly defined in the various extended compiler.py files that live in Triton’s backends. Currently when we require insertion of passes either for instrumentation or for the addition of downstream optimization and custom lowering it is required for the compiler.py file itself to be modified. In order to allow for more downstream configurability and as a first step toward more custom MLIR level pass plugins, we add a hook into the compiler stages to allow for a more configurable pass manager system setup. Using Python inspection routines coupled with the hook allows for more fine grained control of things like enabling/disabling passes for specific kernels with eventually being able to load and insert completely out of tree ops/passes in arbitrary places in the stages pipeline. Co-authored with @plotfi --------- Co-authored-by: Puyan Lotfi <[email protected]>
1 parent 7d92894 commit 905b3d1

File tree

5 files changed

+75
-0
lines changed

5 files changed

+75
-0
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,15 @@ export TRITON_OVERRIDE_DIR=<override_dir>
262262
# Step 4: Run the kernel again to see the overridden result
263263
```
264264

265+
**Compiler Pipeline Inspection Steps**
266+
To introspect the pipeline `add_stages`, before running your kernels, simply set
267+
the add_stages_inspection_hook like so:
268+
269+
```python
270+
def inspect_stages(_self, stages, options, language, capability):
271+
# inspect or modify add_stages here
272+
triton.knobs.runtime.add_stages_inspection_hook = inspect_stages
273+
```
265274

266275
# Changelog
267276

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import triton
2+
from triton import knobs
3+
4+
import os
5+
import pathlib
6+
7+
8+
def test_inspection(monkeypatch, tmp_path: pathlib.Path):
9+
stage_name = 'make_ttgir'
10+
curr_repro_path = tmp_path / ("repro_prefix." + stage_name + ".repro.mlir")
11+
repro_path = tmp_path / "repro_prefix"
12+
13+
monkeypatch.setenv("TRITON_ALWAYS_COMPILE", "1")
14+
monkeypatch.setenv("TRITON_REPRODUCER_PATH", str(repro_path))
15+
16+
inspect_stages_hook_called = False
17+
make_ttgir_wrapper_called = False
18+
19+
def inspect_stages_hook(self, stages, options, language, capability):
20+
nonlocal inspect_stages_hook_called
21+
inspect_stages_hook_called = True
22+
23+
def make_ttgir_wrapper(src, metadata, options, capability):
24+
nonlocal make_ttgir_wrapper_called
25+
make_ttgir_wrapper_called = True
26+
return self.make_ttgir(src, metadata, options, capability)
27+
28+
stages["ttgir"] = lambda src, metadata: make_ttgir_wrapper(src, metadata, options, capability)
29+
30+
@triton.jit
31+
def k1():
32+
return
33+
34+
@triton.jit
35+
def k2():
36+
return
37+
38+
# Run once to get the clean/golden repro dump
39+
k1[(1, )]()
40+
assert not inspect_stages_hook_called and not make_ttgir_wrapper_called
41+
assert os.path.exists(curr_repro_path)
42+
golden_repro = curr_repro_path.read_text()
43+
curr_repro_path.unlink()
44+
45+
# Setup hook and call again, check if hooks got called
46+
knobs.runtime.add_stages_inspection_hook = inspect_stages_hook
47+
k2[(1, )]()
48+
assert inspect_stages_hook_called and make_ttgir_wrapper_called
49+
assert os.path.exists(curr_repro_path)
50+
hook_repro = curr_repro_path.read_text()
51+
52+
# Check that repros match
53+
assert golden_repro.replace('k1', 'dummy') == hook_repro.replace('k2', 'dummy')

python/triton/knobs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,12 @@ def __call__(self, *, key: str, repr: str, fn: JitFunctionInfo, compile: JITHook
447447
...
448448

449449

450+
class PipelineStagesHook(Protocol):
451+
452+
def __call__(self, stages, options, language, capability):
453+
...
454+
455+
450456
class runtime_knobs(base_knobs):
451457
interpret: env_bool = env_bool("TRITON_INTERPRET")
452458
# debug is on critical path for kernel launches
@@ -465,6 +471,9 @@ class runtime_knobs(base_knobs):
465471
# jit_cache_hook will always be called before compilation and jit_post_compile_hook after.
466472
jit_post_compile_hook: Optional[JITHook] = None
467473

474+
# Hook for inspecting compiler pipeline stages
475+
add_stages_inspection_hook: Optional[PipelineStagesHook] = None
476+
468477

469478
class language_knobs(base_knobs):
470479
fp32_default: env_opt_str = env_opt_str("TRITON_F32_DEFAULT")

third_party/amd/backend/compiler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,8 @@ def add_stages(self, stages, options, language):
456456
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
457457
stages["amdgcn"] = lambda src, metadata: self.make_amdgcn(src, metadata, options)
458458
stages["hsaco"] = lambda src, metadata: self.make_hsaco(src, metadata, options)
459+
if knobs.runtime.add_stages_inspection_hook is not None:
460+
knobs.runtime.add_stages_inspection_hook(self, stages, options, language, None)
459461

460462
@functools.lru_cache()
461463
def hash(self):

third_party/nvidia/backend/compiler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,8 @@ def add_stages(self, stages, options, language):
519519
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
520520
stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.target.arch)
521521
stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.target.arch)
522+
if knobs.runtime.add_stages_inspection_hook is not None:
523+
knobs.runtime.add_stages_inspection_hook(self, stages, options, language, capability)
522524

523525
@functools.lru_cache()
524526
def hash(self):

0 commit comments

Comments
 (0)