Skip to content

Commit 55be9d9

Browse files
Implemented compile time/size tracking and profiling utility
1 parent 0ab03be commit 55be9d9

File tree

5 files changed

+498
-104
lines changed

5 files changed

+498
-104
lines changed

.github/workflows/triton-benchmarks.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ env:
6868
VERIFY: ${{ (github.event_name == 'pull_request' || github.event_name == 'schedule' || inputs.verify) && '1' || '0' }}
6969
TAG: ${{ inputs.tag || (github.event_name == 'pull_request' && format('pr-{0}', github.event.number)) || (github.event_name == 'schedule' && 'ci') || 'test' }}
7070
N_RUNS: ${{ inputs.n_runs || '1' }}
71+
TRITON_TRACK_DUMP: "$PWD/reports/track"
7172

7273
jobs:
7374
build:

python/src/ir.cc

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1803,6 +1803,44 @@ void init_triton_ir(py::module &&m) {
18031803
self.printAsTextualPipeline(os);
18041804
return str;
18051805
})
1806+
.def("enable_timing",
1807+
[](PassManager &self, py::function cb) {
1808+
struct CallBackStrategy : OutputStrategy {
1809+
py::function cb;
1810+
1811+
CallBackStrategy(py::function cb)
1812+
: OutputStrategy(llvm::errs()), cb(cb) {}
1813+
1814+
void printHeader(const TimeRecord &total) override {}
1815+
1816+
void printFooter() override {}
1817+
1818+
void printTime(const TimeRecord &time,
1819+
const TimeRecord &total) override {}
1820+
1821+
void printListEntry(StringRef name, const TimeRecord &time,
1822+
const TimeRecord &total,
1823+
bool lastEntry = false) override {
1824+
cb(std::string(name), time.wall, 0);
1825+
}
1826+
1827+
void printTreeEntry(unsigned indent, StringRef name,
1828+
const TimeRecord &time,
1829+
const TimeRecord &total) override {
1830+
cb(std::string(name), time.wall, 1);
1831+
}
1832+
1833+
void printTreeEntryEnd(unsigned indent,
1834+
bool lastEntry = false) override {
1835+
cb(std::string(""), 0., 2);
1836+
}
1837+
};
1838+
1839+
auto tm = std::make_unique<mlir::DefaultTimingManager>();
1840+
tm->setOutput(std::make_unique<CallBackStrategy>(cb));
1841+
tm->setEnabled(true);
1842+
self.enableTiming(std::move(tm));
1843+
})
18061844
.def(
18071845
"run",
18081846
[](PassManager &self, ModuleOp &mod) {

third_party/intel/backend/compiler.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import os
1414
import subprocess
1515
from pathlib import Path
16+
from .track import track
1617

1718

1819
@dataclass
@@ -207,6 +208,7 @@ def get_split_barrier_scope(opt):
207208
return split_barriers_scope
208209

209210
@staticmethod
211+
@track
210212
def make_ttir(mod, metadata, opt):
211213
pm = ir.pass_manager(mod.context)
212214
pm.enable_debug()
@@ -226,6 +228,7 @@ def make_ttir(mod, metadata, opt):
226228
return mod
227229

228230
@staticmethod
231+
@track
229232
def make_ttgir(mod, metadata, opt, properties):
230233
cluster_info = intel.ClusterInfo()
231234
if opt.cluster_dims is not None:
@@ -303,6 +306,7 @@ def gluon_to_ttgir(self, src, metadata, options):
303306
return mod
304307

305308
@staticmethod
309+
@track
306310
def make_llir(src, metadata, options):
307311
mod = src
308312
# TritonGPU -> LLVM-IR (MLIR)
@@ -340,7 +344,9 @@ def make_llir(src, metadata, options):
340344
paths = [path for (name, path) in options.extern_libs]
341345
llvm.link_extern_libs(llvm_mod, paths)
342346

343-
intel.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)
347+
with track("optimize_module") as tr:
348+
intel.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, tr.callback("passes"))
349+
344350
intel.post_process_llir(llvm_mod)
345351

346352
# Get some metadata
@@ -357,6 +363,7 @@ def make_llir(src, metadata, options):
357363
return ret
358364

359365
@staticmethod
366+
@track
360367
def make_spv(src, metadata, options, device_arch):
361368
spirv, name = intel.translate_to_spirv(src)
362369
metadata["name"] = name
@@ -384,7 +391,7 @@ def make_spv(src, metadata, options, device_arch):
384391
metadata["generate_native_code"] = options.generate_native_code
385392

386393
if options.generate_native_code:
387-
with tempfile.TemporaryDirectory() as temp_dir:
394+
with track("generate_native_code"), tempfile.TemporaryDirectory() as temp_dir:
388395
with tempfile.NamedTemporaryFile(mode='wb', suffix='.spv', dir=temp_dir, delete=False) as fsrc:
389396
fsrc.write(spirv)
390397
fbin = fsrc.name + '.o'

0 commit comments

Comments
 (0)