Skip to content

Commit 56a9adf

Browse files
authored
Don't use lambdas in JITFunction to make it picklable (#5900)
PyTorch issue: pytorch/pytorch#146945 Functionality in PyTorch that started relying on serializability of `JITFunction`: pytorch/pytorch#146417 I suppose there are different ways to solve this problem, but at least the current lambdas are not necessary and can be easily rewritten. Signed-off-by: Anatoly Myachev <[email protected]>
1 parent de650ad commit 56a9adf

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

python/triton/runtime/jit.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,9 @@ def run(self, *args, grid, warmup, **kwargs):
587587
*bound_args.values())
588588
return kernel
589589

590+
def repr(self, _):
591+
return self._fn_name if self._repr is None else self._repr(_)
592+
590593
def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_on_alignment=None, debug=None,
591594
noinline=None, repr=None, launch_metadata=None):
592595
do_not_specialize = do_not_specialize if do_not_specialize else []
@@ -599,7 +602,8 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_o
599602
self.do_not_specialize = do_not_specialize
600603
self.do_not_specialize_on_alignment = do_not_specialize_on_alignment
601604
self.starting_line_number = inspect.getsourcelines(fn)[1]
602-
self.repr = lambda _: fn.__name__ if repr is None else repr(_)
605+
self._repr = repr
606+
self._fn_name = fn.__name__
603607
self.launch_metadata = launch_metadata
604608

605609
self.params = []
@@ -613,7 +617,7 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_o
613617
src = src[re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE).start():]
614618
self._unsafe_update_src(src)
615619
# cache of just-in-time compiled kernels
616-
self.device_caches = defaultdict(lambda: self.create_binder())
620+
self.device_caches = defaultdict(self.create_binder)
617621
self.hash = None
618622

619623
# Map of global variables used by the function and any functions it

0 commit comments

Comments
 (0)