Skip to content

Commit 2421984

Browse files
A tracking utility for gathering compile-time and/or run-time statistics: timing, size, profiling and others (#4777)
To enable the tracking, set the environment variable ``TRITON_TRACK_DUMP`` to either ``1``, ``true``, ``yes``, ``on``, ``y`` or a path to a directory where the tracking reports will be dumped. To add the profiling statistics to the reports, set the ``TRITON_TRACK_PROFILE`` environment variable. To track the kernel launches, set the ``TRITON_TRACK_RUN`` environment variable. Link #4716
1 parent 363ee9f commit 2421984

File tree

4 files changed

+518
-104
lines changed

4 files changed

+518
-104
lines changed

python/src/ir.cc

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1866,6 +1866,44 @@ void init_triton_ir(py::module &&m) {
18661866
self.printAsTextualPipeline(os);
18671867
return str;
18681868
})
1869+
.def("enable_timing",
1870+
[](PassManager &self, py::function cb) {
1871+
struct CallBackStrategy : OutputStrategy {
1872+
py::function cb;
1873+
1874+
CallBackStrategy(py::function cb)
1875+
: OutputStrategy(llvm::errs()), cb(cb) {}
1876+
1877+
void printHeader(const TimeRecord &total) override {}
1878+
1879+
void printFooter() override {}
1880+
1881+
void printTime(const TimeRecord &time,
1882+
const TimeRecord &total) override {}
1883+
1884+
void printListEntry(StringRef name, const TimeRecord &time,
1885+
const TimeRecord &total,
1886+
bool lastEntry = false) override {
1887+
cb(std::string(name), time.wall, 0);
1888+
}
1889+
1890+
void printTreeEntry(unsigned indent, StringRef name,
1891+
const TimeRecord &time,
1892+
const TimeRecord &total) override {
1893+
cb(std::string(name), time.wall, 1);
1894+
}
1895+
1896+
void printTreeEntryEnd(unsigned indent,
1897+
bool lastEntry = false) override {
1898+
cb(std::string(""), 0., 2);
1899+
}
1900+
};
1901+
1902+
auto tm = std::make_unique<mlir::DefaultTimingManager>();
1903+
tm->setOutput(std::make_unique<CallBackStrategy>(cb));
1904+
tm->setEnabled(true);
1905+
self.enableTiming(std::move(tm));
1906+
})
18691907
.def(
18701908
"run",
18711909
[](PassManager &self, ModuleOp &mod, std::string repro_pipeline_tag) {

third_party/intel/backend/compiler.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from triton.backends.compiler import BaseBackend, Language
22
from triton._C.libtriton import ir, passes, llvm, intel
33
from triton.backends.intel.driver import compile_module_from_src
4+
from triton.backends.intel.track import track
45
from triton import knobs
56

67
from dataclasses import dataclass
@@ -190,6 +191,7 @@ def get_split_barrier_scope(opt):
190191
return split_barriers_scope
191192

192193
@staticmethod
194+
@track
193195
def make_ttir(mod, metadata, opt):
194196
pm = ir.pass_manager(mod.context)
195197
pm.enable_debug()
@@ -209,6 +211,7 @@ def make_ttir(mod, metadata, opt):
209211
return mod
210212

211213
@staticmethod
214+
@track
212215
def make_ttgir(mod, metadata, opt, properties):
213216
cluster_info = intel.ClusterInfo()
214217
if opt.cluster_dims is not None:
@@ -282,6 +285,7 @@ def gluon_to_ttgir(self, src, metadata, options):
282285
return mod
283286

284287
@staticmethod
288+
@track
285289
def make_llir(src, metadata, options):
286290
mod = src
287291
# TritonGPU -> LLVM-IR (MLIR)
@@ -341,7 +345,9 @@ def make_llir(src, metadata, options):
341345
paths = [path for (name, path) in options.extern_libs]
342346
llvm.link_extern_libs(llvm_mod, paths)
343347

344-
intel.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)
348+
with track("optimize_module") as tr:
349+
intel.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, tr.callback("passes"))
350+
345351
intel.post_process_llir(llvm_mod)
346352

347353
# Get some metadata
@@ -360,6 +366,7 @@ def make_llir(src, metadata, options):
360366
return ret
361367

362368
@staticmethod
369+
@track
363370
def make_spv(src, metadata, options, device_arch):
364371
spirv, name = intel.translate_to_spirv(src)
365372
metadata["name"] = name
@@ -387,7 +394,7 @@ def make_spv(src, metadata, options, device_arch):
387394
metadata["generate_native_code"] = options.generate_native_code
388395

389396
if options.generate_native_code:
390-
with tempfile.TemporaryDirectory() as temp_dir:
397+
with track("generate_native_code"), tempfile.TemporaryDirectory() as temp_dir:
391398
with tempfile.NamedTemporaryFile(mode='wb', suffix='.spv', dir=temp_dir, delete=False) as fsrc:
392399
fsrc.write(spirv)
393400
fbin = fsrc.name + '.o'

0 commit comments

Comments
 (0)