You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[FRONTEND] Clone saved exception before raising (#8115)
triton-lang/triton#7857 has introduced
re-throwing the compilation error in subsequent `run` calls on the
`CompiledKernel` instance. To this end, `functools.partial` was used to
save the error to be raised within a closure of a function that then
replaces the `self.run` method. After being raised, the error saved in
the closure gets a `__traceback__` attached to it, with the latter
holding on to local variables in the stack frames. This is problematic,
because the `CompiledKernel` instance is then saved in the global kernel
cache which is maintained for the duration of the program. As a result,
the objects from the stack trace (e.g., Tensor inputs to the Triton
kernel) won't be freed leading to a memory leak.
This PR fixes the issue above by cloning the exception to be raised
*before* raising it. The cloning needs to be done both before creating
the closure with `functools.partial` and within the `_raise_error`
function, as if the saved exception instance is raised by
`_raise_error`, it will get a traceback attached to it leading to the
same problem.
P.S. triton-lang/triton#7857 has caused some CI
jobs failing in PyTorch CI. The error: CUDAGraph capture in PT2
complains about dangling tensors after the model run. Investigation has
pointed to the issue solved by this PR. For more details see the Triton
update tracker issue in PyTorch
pytorch/pytorch#159704.
0 commit comments