Skip to content

Commit fde96e8

Browse files
authored
Remove redundant calls to inspect.getsource (#7588)
Triton calls `inspect.getsource/getsourcelines` a fair amount during compilation. The second call in `JITFunction.__init__` is completely redundant since `getsource` just calls `getsourcelines` under the hood, and caching the source lines on `JITFunction` also lets you skip the calls in `get_jit_fn_file_line`. `getsource` is quite expensive, so this saves >5% of the overall compile time. (Sadly, I think this is the last of the easy compile time wins.)
1 parent 981b0bb commit fde96e8

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

python/triton/runtime/jit.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,7 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_o
632632
self.signature = inspect.signature(fn)
633633
self.do_not_specialize = do_not_specialize
634634
self.do_not_specialize_on_alignment = do_not_specialize_on_alignment
635-
self.starting_line_number = inspect.getsourcelines(fn)[1]
635+
self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
636636
self._repr = repr
637637
self._fn_name = get_full_name(fn)
638638
self.launch_metadata = launch_metadata
@@ -644,7 +644,7 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_o
644644
self.params.append(KernelParam(i, param, dns, dns_oa))
645645

646646
# function source code (without decorators)
647-
src = textwrap.dedent(inspect.getsource(fn))
647+
src = textwrap.dedent("".join(self.raw_src))
648648
src = src[re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE).start():]
649649
self._unsafe_update_src(src)
650650
# cache of just-in-time compiled kernels
@@ -977,13 +977,13 @@ def get_jit_fn_file_line(fn):
977977
while not isinstance(base_fn, JITFunction):
978978
base_fn = base_fn.fn
979979
file_name = base_fn.fn.__code__.co_filename
980-
lines, begin_line = inspect.getsourcelines(base_fn.fn)
980+
begin_line = base_fn.starting_line_number
981981
# Match the following pattern:
982982
# @triton.autotune(...) <- foo.__code__.co_firstlineno
983983
# @triton.heuristics(...)
984984
# @triton.jit
985985
# def foo(...): <- this line is the first line
986-
for idx, line in enumerate(lines):
986+
for idx, line in enumerate(base_fn.raw_src):
987987
if line.strip().startswith("def "):
988988
begin_line += idx
989989
break

0 commit comments

Comments
 (0)