Skip to content

Commit de2e53a

Browse files
authored
[release/2.7] Tensor input dumping for triton kernels (#2716)
Related to one of the customer issues, but will be useful later
1 parent 6110a92 commit de2e53a

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

torch/_inductor/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,6 +1162,15 @@ class triton:
11621162
# Note: it may also need to be used with config.compile_threads = 1
11631163
disallow_failing_autotune_kernels_TESTING_ONLY = False
11641164

1165+
# Map for storing the number of kernel runs with dumped input tensors.
# Keyed by the hash of the Triton source code to avoid bloating the folder.
# NOTE: the (misspelled) names are kept as-is because they are referenced
# from triton_heuristics.py; renaming would break that caller.
kernel_dump_occurency_map: dict[str, int] = {}

# Maximum number of runs with dumped kernel input tensors to keep per kernel.
# When the maximum is reached the earliest dumps get overwritten,
# which ensures the last N runs are saved, where N is this value.
max_kernel_dump_occurencies: int = 3
1173+
11651174

11661175
class aot_inductor:
11671176
# AOTInductor output path

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import torch
2424
from torch._prims_common import compute_required_storage_length
2525
from torch.utils._ordered_set import OrderedSet
26+
from torch._inductor.config import triton as inuctor_triton_config
2627

2728
from ..triton_bundler import TritonBundler
2829
from ..utils import prefix_is_reduction, triton_version_uses_attrs_dict
@@ -164,6 +165,39 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
164165
f.write(f"{kernel_name} | {args_str} | {grid!r}\n")
165166

166167

168+
def _dump_launch_tensors(args, kernel_path, kernel_hash, kernel_name):
    """Save every tensor argument of a kernel launch to disk for debugging.

    Each run writes the tensor arguments as ``tensor_<i>.pt`` files into a
    per-run directory placed next to the compiled kernel (or under a fixed
    fallback directory when the kernel has no stored path).  Only the last
    ``max_kernel_dump_occurencies`` runs per kernel are kept: the run index
    wraps around, so older run directories are overwritten round-robin.

    Args:
        args: launch arguments; only ``torch.Tensor`` entries are dumped.
        kernel_path: path of the compiled kernel file, may be falsy.
        kernel_hash: hash of the Triton source code, used as the map key.
        kernel_name: kernel name; used as the key when no path/hash exists.
    """
    tensor_list = [arg for arg in args if isinstance(arg, torch.Tensor)]

    # Some kernels don't have a path and hash stored;
    # fall back to the name alone to differentiate between those.
    if not kernel_path:
        kernel_hash = kernel_name

    # Keep only the last N runs of each kernel to avoid bloating the folder:
    # advance the per-kernel run index and wrap it once the maximum is hit.
    run_index = (
        inuctor_triton_config.kernel_dump_occurency_map.get(kernel_hash, -1) + 1
    )
    if run_index >= inuctor_triton_config.max_kernel_dump_occurencies:
        run_index = 0
    inuctor_triton_config.kernel_dump_occurency_map[kernel_hash] = run_index

    # Default path for kernels with no hash.
    if not kernel_path:
        directory_path = "/tmp/torchinductor_root/unhashed_kernel_inputs"
    else:
        directory_path = os.path.dirname(kernel_path)
    directory_path = f"{directory_path}/{kernel_name}_run_{run_index}"
    os.makedirs(directory_path, exist_ok=True)

    for tensor_index, tensor in enumerate(tensor_list):
        torch.save(tensor, f"{directory_path}/tensor_{tensor_index}.pt")
199+
200+
167201
class CachingAutotuner(KernelInterface):
168202
"""
169203
Simplified version of Triton autotuner that has no invalidation
@@ -255,6 +289,10 @@ def __init__(
255289
self.dump_launch_params = (
256290
os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", "0") == "1"
257291
)
292+
self.dump_launch_tensors = (
293+
os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_TENSORS", "0") == "1"
294+
)
295+
self.kernels_to_dump = os.environ.get("TORCHINDUCTOR_KERNELS_TO_DUMP", "").split(",")
258296

259297
self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1"
260298

@@ -925,6 +963,11 @@ def run(
925963
new_args, grid = self._interpret_args_grid(args, launcher.config)
926964
_dump_launch_params(new_args, kwargs, launcher, self.fn.__name__, grid)
927965

966+
if self.dump_launch_tensors:
967+
# Check the kernel name if the list was provided
968+
if not self.kernels_to_dump or any(kernel_name in self.fn.__name__ for kernel_name in self.kernels_to_dump):
969+
_dump_launch_tensors(args, self.filename, self.kernel_hash, self.fn.__name__)
970+
928971
# it is faster than entering and exiting a context manager, even if the context
929972
# manager is a nullcontext.
930973
if autograd_profiler._is_profiler_enabled:

0 commit comments

Comments
 (0)