Commit 6fa1dd6

[FRONTEND] Clone saved exception before raising (#8115)
triton-lang/triton#7857 introduced re-raising the compilation error in subsequent `run` calls on the `CompiledKernel` instance. To do this, `functools.partial` was used to save the error in a partial of `_raise_error`, which then replaces `self._run`. Once raised, the saved error gets a `__traceback__` attached to it, and the traceback holds on to the local variables of the stack frames. This is a problem because the `CompiledKernel` instance is stored in the global kernel cache, which is kept for the duration of the program. As a result, objects reachable from the stack trace (e.g., Tensor inputs to the Triton kernel) are never freed, which leaks memory.

This PR fixes the issue by cloning the exception *before* raising it. The cloning has to happen in two places: before creating the partial with `functools.partial`, and inside `_raise_error`. If `_raise_error` raised the saved exception instance directly, that instance would get a traceback attached to it and the same problem would return.

P.S. triton-lang/triton#7857 caused some jobs to fail in PyTorch CI. The error: CUDAGraph capture in PT2 complains about dangling tensors after the model run. Investigation pointed to the issue fixed by this PR. For more details, see the Triton update tracker issue in PyTorch, pytorch/pytorch#159704.
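The leak described above can be reproduced outside Triton. The following is a minimal sketch, not Triton code: `Payload`, `_raise_saved`, `_raise_clone`, and `call_with_payload` are hypothetical names introduced for illustration, with `Payload` standing in for a tensor argument. It assumes CPython, where `copy.deepcopy` on an exception produces a new instance without the original's `__traceback__`.

```python
import copy
import functools
import gc
import weakref


class Payload:
    """Hypothetical stand-in for a large kernel argument such as a tensor."""


def _raise_saved(err, *args, **kwargs):
    # Raising the saved instance attaches a traceback to it. The traceback
    # references this frame, whose locals include `args` (the payload).
    raise err


def _raise_clone(err, *args, **kwargs):
    # Raising a clone leaves the saved instance without a traceback.
    raise copy.deepcopy(err)


def call_with_payload(run):
    payload = Payload()
    ref = weakref.ref(payload)
    try:
        run(payload)
    except RuntimeError:
        pass
    del payload
    gc.collect()
    return ref


# Module-level partials play the role of the long-lived kernel cache.
leaky_run = functools.partial(_raise_saved, RuntimeError("compile failed"))
safe_run = functools.partial(_raise_clone, RuntimeError("compile failed"))

leaky_ref = call_with_payload(leaky_run)
safe_ref = call_with_payload(safe_run)

print("payload kept alive by saved exception:", leaky_ref() is not None)
print("payload kept alive by cloned exception:", safe_ref() is not None)
```

In the first case the exception stored in the partial keeps the payload reachable through its traceback; in the second, only the short-lived clone ever carries a traceback, so the payload can be collected once the handler exits.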
1 parent 4f5f43e commit 6fa1dd6

1 file changed: +9 -2 lines changed

python/triton/compiler/compiler.py

@@ -15,6 +15,7 @@
 import functools
 import os
 import time
+import copy
 
 # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
 # and any following whitespace
@@ -404,7 +405,7 @@ def __missing__(self, key):
 
 
 def _raise_error(err, *args, **kwargs):
-    raise err
+    raise copy.deepcopy(err)
 
 
 class CompiledKernel:
@@ -445,7 +446,13 @@ def _init_handles(self):
             return
 
         def raise_(err):
-            self._run = functools.partial(_raise_error, err)
+            # clone the exception object so that the one saved in the closure
+            # of the partial function below doesn't get assigned a stack trace
+            # after the subsequent raise. otherwise, the CompiledKernel instance
+            # saved in the (global) kernel cache will keep references to all the
+            # locals in the traceback via the exception instance in the closure.
+            cloned_err = copy.deepcopy(err)
+            self._run = functools.partial(_raise_error, cloned_err)
             raise err
 
         device = driver.active.get_current_device()
