9 changes: 9 additions & 0 deletions torch/_inductor/config.py
@@ -1468,6 +1468,15 @@ class triton:
    # Programmatic Dependent Launch improves launch latency on Nvidia Hopper+ devices
    enable_pdl = False

    # Map storing the number of kernel runs with dumped input tensors,
    # keyed by the hash of the Triton source code to avoid bloating the folder
    kernel_dump_occurrence_map: dict[str, int] = {}

    # Maximum number of runs for which kernel input tensors are dumped.
    # When the maximum is reached, the earliest dumps are overwritten,
    # so the last N runs are kept, where N is this value
    max_kernel_dump_occurrences = 3
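
    # Illustration (hypothetical values): with max_kernel_dump_occurrences = 3,
    # successive runs of a kernel are written to <kernel_dir>/<name>_run_0,
    # _run_1, _run_2, and the fourth run wraps around to overwrite <name>_run_0.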


class aot_inductor:
"""
43 changes: 43 additions & 0 deletions torch/_inductor/runtime/triton_heuristics.py
@@ -34,6 +34,7 @@
from torch._environment import is_fbcode
from torch._prims_common import compute_required_storage_length
from torch.utils._ordered_set import OrderedSet
from torch._inductor.config import triton as inductor_triton_config

from ..triton_bundler import TritonBundler
from ..utils import prefix_is_reduction, triton_version_uses_attrs_dict
@@ -223,6 +224,39 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
        f.write(f"{kernel_name} | {args_str} | {grid!r}\n")


def _dump_launch_tensors(args, kernel_path, kernel_hash, kernel_name):
    tensor_list = [arg for arg in args if isinstance(arg, torch.Tensor)]

    run_index = 0

    # Some kernels don't have a path and hash stored;
    # use the name alone to differentiate between those
    if not kernel_path:
        kernel_hash = kernel_name

    # Save only the last N runs of each kernel to avoid bloating the folder
    if kernel_hash in inductor_triton_config.kernel_dump_occurrence_map:
        run_index = inductor_triton_config.kernel_dump_occurrence_map[kernel_hash] + 1

        if run_index >= inductor_triton_config.max_kernel_dump_occurrences:
            run_index = 0

    inductor_triton_config.kernel_dump_occurrence_map[kernel_hash] = run_index

    # Default path for kernels with no hash
    if not kernel_path:
        directory_path = "/tmp/torchinductor_root/unhashed_kernel_inputs"
    else:
        directory_path = os.path.dirname(kernel_path)
    directory_path = f"{directory_path}/{kernel_name}_run_{run_index}"
    os.makedirs(directory_path, exist_ok=True)

    for tensor_index, tensor in enumerate(tensor_list):
        torch.save(tensor, f"{directory_path}/tensor_{tensor_index}.pt")
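
# A minimal companion sketch (illustration only, not part of this PR): one way
# to load the dumped tensors back for offline inspection, assuming the
# tensor_<index>.pt layout written by _dump_launch_tensors above.
def _load_dumped_tensors(run_directory):
    import glob

    paths = sorted(
        glob.glob(f"{run_directory}/tensor_*.pt"),
        # Sort numerically so tensor_10.pt comes after tensor_2.pt
        key=lambda p: int(p.rsplit("_", 1)[1].removesuffix(".pt")),
    )
    return [torch.load(p) for p in paths]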


def check_autotune_cache(
    configs: list[Config], filename: Optional[str], inductor_meta: dict[str, Any]
) -> tuple[list[Config], Optional[AutotuneCache], dict[str, Any]]:
@@ -367,6 +401,10 @@ def __init__(
        self.dump_launch_params = (
            os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", "0") == "1"
        )
        self.dump_launch_tensors = (
            os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_TENSORS", "0") == "1"
        )
        # Drop the empty string that "".split(",") yields when the var is unset
        self.kernels_to_dump = [
            name
            for name in os.environ.get("TORCHINDUCTOR_KERNELS_TO_DUMP", "").split(",")
            if name
        ]
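        # Example invocation (hypothetical kernel name and script), dumping
        # input tensors for matching kernels only:
        #   TORCHINDUCTOR_DUMP_LAUNCH_TENSORS=1 \
        #   TORCHINDUCTOR_KERNELS_TO_DUMP=triton_poi_fused_add python repro.py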

        self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1"

@@ -1306,6 +1344,11 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
            new_args, grid = self._interpret_args_grid(args, launcher.config)
            _dump_launch_params(new_args, kwargs, launcher, self.fn.__name__, grid)

        if self.dump_launch_tensors:
            # Check the kernel name against the list, if one was provided
            if not self.kernels_to_dump or any(
                name in self.fn.__name__ for name in self.kernels_to_dump
            ):
                _dump_launch_tensors(args, self.filename, self.kernel_hash, self.fn.__name__)

        # it is faster than entering and exiting a context manager, even if the
        # context manager is a nullcontext.
        if autograd_profiler._is_profiler_enabled: