Skip to content

Commit fb68aea

Browse files
[IMP][Launch Latency] native specialize (#7771)
This PR is the fifth in a series of contributions aiming at reducing the launch overhead of Triton kernels ran without CUDA Graphs. A latency profiler script shared offline currently puts the main branch at a latency of around `27.16 us` (on a AMD EPYC 7413 24-C system with a H100-HBM3 GPU) which can be reduced via several contributions in different places. One remaining larger amount of time during launch is spent in `specialize_impl` and related calls which are necessary to process the signature and create the `specialization` for a kernel-cache lookup and eventually launching a kernel. Within that current logic, there are two things that cost time in particular - multiple calls to `specialize_impl` (each function call in Python is by `PyEval_EvalFrame` calls related to interpreting that function, doing some setup, and eventually GC afterwards) for each argument which can be up to 100ns per call - a surprisingly long time spent in calculating the alignment from native Python types This PR thus addresses these two issues in two major parts - "native" implementations of specializing integers and data-pointers which cuts down time spent in computing alignments and divisibility - a manual "inlining" of some of these specialization calls in `dynamic_func` which avoids some of the mentioned overheads from function calls above This PR also comes with two minor improvements accompanying these changes - slightly re-ordering the if/else conditions in the specialization logic - favoring types used more often - adding another branch for finding tensors based on its class name which should be faster than trying to access `data_ptr` Overall, this cuts down latency reported in the shared profiling script to `21.68 us` (from `27.16 us`). | name | PR | latency | reduction | |------|----|---------|-----------| | main | x | `27.16 us` | x | | cache-knob | #7767 | `25.10 us` | `2.06 us` | | native key | #7768 | `21.71 us` | `5.45 us` | | backend `GetAttrString` | #7769 | `26.80 us` | `0.36 us` | | misc compiler/kernel | #7770 | `23.90 us` | `3.26 us`| | native-specialize | #7771 | `21.68 us` | `5.48 us` | | **total** | x | `~10.5 us` | `16.61 us` | # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [ ] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [x] This PR does not need a test because `it should be covered by existing tests (no new functionality)`. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.) --------- Co-authored-by: peterbell10 <[email protected]> Co-authored-by: Peter Bell <[email protected]>
1 parent 8e87ed6 commit fb68aea

File tree

8 files changed

+596
-94
lines changed

8 files changed

+596
-94
lines changed

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,8 @@ if(TRITON_BUILD_PYTHON_MODULE)
288288
${PYTHON_SRC_PATH}/gluon_ir.cc
289289
${PYTHON_SRC_PATH}/passes.cc
290290
${PYTHON_SRC_PATH}/interpreter.cc
291-
${PYTHON_SRC_PATH}/llvm.cc)
291+
${PYTHON_SRC_PATH}/llvm.cc
292+
${PYTHON_SRC_PATH}/specialize.cc)
292293

293294
# Link triton with its dependencies
294295
target_link_libraries(triton PRIVATE ${TRITON_LIBRARIES})

python/src/main.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,14 @@ void init_triton_interpreter(pybind11::module &&m);
4343
void init_triton_passes(pybind11::module &&m);
4444
void init_triton_stacktrace_hook(pybind11::module &m);
4545
void init_gluon_ir(pybind11::module &&m);
46+
void init_native_specialize(pybind11::module &m);
4647
FOR_EACH_P(DECLARE_BACKEND, TRITON_BACKENDS_TUPLE)
4748

4849
PYBIND11_MODULE(libtriton, m) {
4950
m.doc() = "Python bindings to the C++ Triton API";
5051
init_triton_stacktrace_hook(m);
5152
init_triton_env_vars(m);
53+
init_native_specialize(m);
5254
init_triton_ir(m.def_submodule("ir"));
5355
init_triton_passes(m.def_submodule("passes"));
5456
init_triton_interpreter(m.def_submodule("interpreter"));

0 commit comments

Comments
 (0)