Skip to content

Commit f73e9a4

Browse files
agron911meta-codesync[bot]
authored andcommitted
[triton][beta] [Cherry-pick] '[BACKEND] Add hook for configurable/overridable compiler pass pipeline (#8137)' (#1014)
Summary: Pull Request resolved: #1014 This is a cherry-pick of an upstream PR: triton-lang/triton#8137 Upstream commit message: ``` > [BACKEND] Add hook for configurable/overridable compiler pass pipeline (#8137) > Triton’s existing pass pipelines are explicitly defined in the various > extended compiler.py files that live in Triton’s backends. Currently > when we require insertion of passes either for instrumentation or for > the addition of downstream optimization and custom lowering it is > required for the compiler.py file itself to be modified. > In order to allow for more downstream configurability and as a first > step toward more custom MLIR level pass plugins, we add a hook into the > compiler stages to allow for a more configurable pass manager system > setup. > Using Python inspection routines coupled with the hook allows for more > fine grained control of things like enabling/disabling passes for > specific kernels with eventually being able to load and insert > completely out of tree ops/passes in arbitrary places in the stages > pipeline. > Co-authored with plotfi > --------- > Co-authored-by: Puyan Lotfi <puyan@puyan.org> ``` ***Do not remove the following line from this commit*** Reactor Cherry-pick Revision: 905b3d1 --- This diff was generated by running: ``` buck run fbcode//triton/tools/reactor:reactor -- cherrypick --num-commits 1 ``` Reviewed By: dshi7 Differential Revision: D94678547 fbshipit-source-id: a2af93c1274f56130c21b2b91621b7fea3763121
1 parent 345c368 commit f73e9a4

File tree

5 files changed

+75
-0
lines changed

5 files changed

+75
-0
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,15 @@ qk_storage_alias.set_buffer_overlap(
133133
)
134134
```
135135

136+
**Compiler Pipeline Inspection Steps**
137+
To introspect the pipeline `add_stages`, before running your kernels, simply set
138+
the add_stages_inspection_hook like so:
139+
140+
```python
141+
def inspect_stages(_self, stages, options, language, capability):
142+
# inspect or modify add_stages here
143+
triton.knobs.runtime.add_stages_inspection_hook = inspect_stages
144+
```
136145

137146

138147
### Remote buffer operations
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import triton
2+
from triton import knobs
3+
4+
import os
5+
import pathlib
6+
7+
8+
def test_inspection(monkeypatch, tmp_path: pathlib.Path):
9+
stage_name = 'make_ttgir'
10+
curr_repro_path = tmp_path / ("repro_prefix." + stage_name + ".repro.mlir")
11+
repro_path = tmp_path / "repro_prefix"
12+
13+
monkeypatch.setenv("TRITON_ALWAYS_COMPILE", "1")
14+
monkeypatch.setenv("TRITON_REPRODUCER_PATH", str(repro_path))
15+
16+
inspect_stages_hook_called = False
17+
make_ttgir_wrapper_called = False
18+
19+
def inspect_stages_hook(self, stages, options, language, capability):
20+
nonlocal inspect_stages_hook_called
21+
inspect_stages_hook_called = True
22+
23+
def make_ttgir_wrapper(src, metadata, options, capability):
24+
nonlocal make_ttgir_wrapper_called
25+
make_ttgir_wrapper_called = True
26+
return self.make_ttgir(src, metadata, options, capability)
27+
28+
stages["ttgir"] = lambda src, metadata: make_ttgir_wrapper(src, metadata, options, capability)
29+
30+
@triton.jit
31+
def k1():
32+
return
33+
34+
@triton.jit
35+
def k2():
36+
return
37+
38+
# Run once to get the clean/golden repro dump
39+
k1[(1, )]()
40+
assert not inspect_stages_hook_called and not make_ttgir_wrapper_called
41+
assert os.path.exists(curr_repro_path)
42+
golden_repro = curr_repro_path.read_text()
43+
curr_repro_path.unlink()
44+
45+
# Setup hook and call again, check if hooks got called
46+
knobs.runtime.add_stages_inspection_hook = inspect_stages_hook
47+
k2[(1, )]()
48+
assert inspect_stages_hook_called and make_ttgir_wrapper_called
49+
assert os.path.exists(curr_repro_path)
50+
hook_repro = curr_repro_path.read_text()
51+
52+
# Check that repros match
53+
assert golden_repro.replace('k1', 'dummy') == hook_repro.replace('k2', 'dummy')

python/triton/knobs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,12 @@ def __call__(self, *, key: str, repr: str, fn: JitFunctionInfo, compile: JITHook
455455
...
456456

457457

458+
class PipelineStagesHook(Protocol):
459+
460+
def __call__(self, stages, options, language, capability):
461+
...
462+
463+
458464
class runtime_knobs(base_knobs):
459465
interpret: env_bool = env_bool("TRITON_INTERPRET")
460466
# debug is on critical path for kernel launches
@@ -475,6 +481,9 @@ class runtime_knobs(base_knobs):
475481
# jit_cache_hook will always be called before compilation and jit_post_compile_hook after.
476482
jit_post_compile_hook: Optional[JITHook] = None
477483

484+
# Hook for inspecting compiler pipeline stages
485+
add_stages_inspection_hook: Optional[PipelineStagesHook] = None
486+
478487

479488
class language_knobs(base_knobs):
480489
fp32_default: env_opt_str = env_opt_str("TRITON_F32_DEFAULT")

third_party/amd/backend/compiler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,8 @@ def add_stages(self, stages, options, language):
501501
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
502502
stages["amdgcn"] = lambda src, metadata: self.make_amdgcn(src, metadata, options)
503503
stages["hsaco"] = lambda src, metadata: self.make_hsaco(src, metadata, options)
504+
if knobs.runtime.add_stages_inspection_hook is not None:
505+
knobs.runtime.add_stages_inspection_hook(self, stages, options, language, None)
504506

505507
@functools.lru_cache()
506508
def hash(self):

third_party/nvidia/backend/compiler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,8 @@ def add_stages(self, stages, options, language):
617617
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
618618
stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.target.arch)
619619
stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.target.arch)
620+
if knobs.runtime.add_stages_inspection_hook is not None:
621+
knobs.runtime.add_stages_inspection_hook(self, stages, options, language, capability)
620622

621623
@functools.lru_cache()
622624
def hash(self):

0 commit comments

Comments
 (0)