Skip to content

Commit f16b622

Browse files
Implemented compile time/size tracking and profiling utility
1 parent ef210bc commit f16b622

File tree

5 files changed

+502
-104
lines changed

5 files changed

+502
-104
lines changed

.github/workflows/triton-benchmarks.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ env:
6868
VERIFY: ${{ (github.event_name == 'pull_request' || github.event_name == 'schedule' || inputs.verify) && '1' || '0' }}
6969
TAG: ${{ inputs.tag || (github.event_name == 'pull_request' && format('pr-{0}', github.event.number)) || (github.event_name == 'schedule' && 'ci') || 'test' }}
7070
N_RUNS: ${{ inputs.n_runs || '1' }}
71+
TRITON_TRACK_DUMP: "$PWD/reports/track"
7172

7273
jobs:
7374
build:

python/src/ir.cc

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1814,6 +1814,44 @@ void init_triton_ir(py::module &&m) {
18141814
self.printAsTextualPipeline(os);
18151815
return str;
18161816
})
1817+
.def("enable_timing",
1818+
[](PassManager &self, py::function cb) {
1819+
struct CallBackStrategy : OutputStrategy {
1820+
py::function cb;
1821+
1822+
CallBackStrategy(py::function cb)
1823+
: OutputStrategy(llvm::errs()), cb(cb) {}
1824+
1825+
void printHeader(const TimeRecord &total) override {}
1826+
1827+
void printFooter() override {}
1828+
1829+
void printTime(const TimeRecord &time,
1830+
const TimeRecord &total) override {}
1831+
1832+
void printListEntry(StringRef name, const TimeRecord &time,
1833+
const TimeRecord &total,
1834+
bool lastEntry = false) override {
1835+
cb(std::string(name), time.wall, 0);
1836+
}
1837+
1838+
void printTreeEntry(unsigned indent, StringRef name,
1839+
const TimeRecord &time,
1840+
const TimeRecord &total) override {
1841+
cb(std::string(name), time.wall, 1);
1842+
}
1843+
1844+
void printTreeEntryEnd(unsigned indent,
1845+
bool lastEntry = false) override {
1846+
cb(std::string(""), 0., 2);
1847+
}
1848+
};
1849+
1850+
auto tm = std::make_unique<mlir::DefaultTimingManager>();
1851+
tm->setOutput(std::make_unique<CallBackStrategy>(cb));
1852+
tm->setEnabled(true);
1853+
self.enableTiming(std::move(tm));
1854+
})
18171855
.def(
18181856
"run",
18191857
[](PassManager &self, ModuleOp &mod) {

third_party/intel/backend/compiler.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import os
1414
import subprocess
1515
from pathlib import Path
16+
from .track import track
1617

1718

1819
@dataclass
@@ -211,6 +212,7 @@ def get_split_barrier_scope(opt):
211212
return split_barriers_scope
212213

213214
@staticmethod
215+
@track
214216
def make_ttir(mod, metadata, opt):
215217
pm = ir.pass_manager(mod.context)
216218
pm.enable_debug()
@@ -230,6 +232,7 @@ def make_ttir(mod, metadata, opt):
230232
return mod
231233

232234
@staticmethod
235+
@track
233236
def make_ttgir(mod, metadata, opt, properties):
234237
cluster_info = intel.ClusterInfo()
235238
if opt.cluster_dims is not None:
@@ -307,6 +310,7 @@ def gluon_to_ttgir(self, src, metadata, options):
307310
return mod
308311

309312
@staticmethod
313+
@track
310314
def make_llir(src, metadata, options):
311315
mod = src
312316
# TritonGPU -> LLVM-IR (MLIR)
@@ -348,7 +352,9 @@ def make_llir(src, metadata, options):
348352
paths = [path for (name, path) in options.extern_libs]
349353
llvm.link_extern_libs(llvm_mod, paths)
350354

351-
intel.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)
355+
with track("optimize_module") as tr:
356+
intel.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, tr.callback("passes"))
357+
352358
intel.post_process_llir(llvm_mod)
353359

354360
# Get some metadata
@@ -367,6 +373,7 @@ def make_llir(src, metadata, options):
367373
return ret
368374

369375
@staticmethod
376+
@track
370377
def make_spv(src, metadata, options, device_arch):
371378
spirv, name = intel.translate_to_spirv(src)
372379
metadata["name"] = name
@@ -394,7 +401,7 @@ def make_spv(src, metadata, options, device_arch):
394401
metadata["generate_native_code"] = options.generate_native_code
395402

396403
if options.generate_native_code:
397-
with tempfile.TemporaryDirectory() as temp_dir:
404+
with track("generate_native_code"), tempfile.TemporaryDirectory() as temp_dir:
398405
with tempfile.NamedTemporaryFile(mode='wb', suffix='.spv', dir=temp_dir, delete=False) as fsrc:
399406
fsrc.write(spirv)
400407
fbin = fsrc.name + '.o'

0 commit comments

Comments
 (0)