Skip to content

Commit 2b5fc74

Browse files
committed
[release/2.7] Tensor input dumping for triton kernels (#2716)
Related to one of the customer issues, but will be useful later
1 parent 55c9130 commit 2b5fc74

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

torch/_inductor/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1442,6 +1442,15 @@ class triton:
14421442
os.environ.get("TORCHINDUCTOR_DECOMPOSE_K_THRESHOLD", "32")
14431443
)
14441444

1445+
# Map for storing the number of kernel runs with dumped input tensors
1446+
# Based on hash of Triton source code to avoid bloating the folder
1447+
kernel_dump_occurency_map: dict[str, int] = {}
1448+
1449+
# Value for the maximum number of runs with dumped kernel input tensors
1450+
# When the maximum is reached the first values get overwritten
1451+
# This ensures the last N runs are saved, where N is this value
1452+
max_kernel_dump_occurencies = 3
1453+
14451454

14461455
class aot_inductor:
14471456
"""

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from torch._environment import is_fbcode
3535
from torch._prims_common import compute_required_storage_length
3636
from torch.utils._ordered_set import OrderedSet
37+
from torch._inductor.config import triton as inuctor_triton_config
3738

3839
from ..triton_bundler import TritonBundler
3940
from ..utils import prefix_is_reduction, triton_version_uses_attrs_dict
@@ -223,6 +224,39 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
223224
f.write(f"{kernel_name} | {args_str} | {grid!r}\n")
224225

225226

227+
def _dump_launch_tensors(args, kernel_path, kernel_hash, kernel_name):
    """Save every tensor argument of a kernel launch to disk for debugging.

    Args:
        args: Launch arguments; only ``torch.Tensor`` entries are dumped.
        kernel_path: Path of the compiled kernel file, or falsy when unknown.
        kernel_hash: Hash of the Triton source code, used as the key of the
            per-kernel run counter.
        kernel_name: Kernel name, used in the dump directory name (and as the
            counter key when no path/hash is available).
    """
    tensors = [arg for arg in args if isinstance(arg, torch.Tensor)]

    # Some kernels don't have a path and hash stored.
    # Use only the name to differentiate between those.
    if not kernel_path:
        kernel_hash = kernel_name

    # Save only the last N runs of each kernel to avoid bloating the folder:
    # the per-kernel counter wraps around, so the oldest dump directories
    # get overwritten first.
    run_index = 0
    if kernel_hash in inuctor_triton_config.kernel_dump_occurency_map:
        run_index = inuctor_triton_config.kernel_dump_occurency_map[kernel_hash] + 1
        if run_index >= inuctor_triton_config.max_kernel_dump_occurencies:
            run_index = 0
    inuctor_triton_config.kernel_dump_occurency_map[kernel_hash] = run_index

    # Default path for kernels with no hash.
    if not kernel_path:
        directory_path = "/tmp/torchinductor_root/unhashed_kernel_inputs"
    else:
        directory_path = os.path.dirname(kernel_path)
    directory_path = f"{directory_path}/{kernel_name}_run_{run_index}"
    os.makedirs(directory_path, exist_ok=True)

    # enumerate replaces the manual tensor_index counter of the original.
    for tensor_index, tensor in enumerate(tensors):
        torch.save(tensor, f"{directory_path}/tensor_{tensor_index}.pt")
226260
def check_autotune_cache(
227261
configs: list[Config], filename: Optional[str], inductor_meta: dict[str, Any]
228262
) -> tuple[list[Config], Optional[AutotuneCache], dict[str, Any]]:
@@ -367,6 +401,10 @@ def __init__(
367401
self.dump_launch_params = (
368402
os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", "0") == "1"
369403
)
404+
self.dump_launch_tensors = (
405+
os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_TENSORS", "0") == "1"
406+
)
407+
self.kernels_to_dump = os.environ.get("TORCHINDUCTOR_KERNELS_TO_DUMP", "").split(",")
370408

371409
self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1"
372410

@@ -1291,6 +1329,11 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
12911329
new_args, grid = self._interpret_args_grid(args, launcher.config)
12921330
_dump_launch_params(new_args, kwargs, launcher, self.fn.__name__, grid)
12931331

1332+
if self.dump_launch_tensors:
1333+
# Check the kernel name if the list was provided
1334+
if not self.kernels_to_dump or any(kernel_name in self.fn.__name__ for kernel_name in self.kernels_to_dump):
1335+
_dump_launch_tensors(args, self.filename, self.kernel_hash, self.fn.__name__)
1336+
12941337
# it is faster than entering and exiting a context manager, even if the context
12951338
# manager is a nullcontext.
12961339
if autograd_profiler._is_profiler_enabled:

0 commit comments

Comments
 (0)