diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index 857272df14c9..cbe2bb44f6c8 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -1442,6 +1442,15 @@ class triton:
         os.environ.get("TORCHINDUCTOR_DECOMPOSE_K_THRESHOLD", "32")
     )
 
+    # Map storing the number of kernel runs with dumped input tensors,
+    # keyed by the hash of the Triton source code to avoid bloating the folder
+    kernel_dump_occurrence_map: dict[str, int] = {}
+
+    # Maximum number of runs with dumped kernel input tensors.
+    # When the maximum is reached, the earliest runs are overwritten,
+    # ensuring the last N runs are saved, where N is this value
+    max_kernel_dump_occurrences = 3
+
 
 class aot_inductor:
     """
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 1de1f9a595c9..ffcfb98a6bf3 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -34,6 +34,7 @@ from torch._environment import is_fbcode
 from torch._prims_common import compute_required_storage_length
 from torch.utils._ordered_set import OrderedSet
 
+from torch._inductor.config import triton as inductor_triton_config
 from ..triton_bundler import TritonBundler
 from ..utils import prefix_is_reduction, triton_version_uses_attrs_dict
 
@@ -223,6 +224,37 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
         f.write(f"{kernel_name} | {args_str} | {grid!r}\n")
 
 
+def _dump_launch_tensors(args, kernel_path, kernel_hash, kernel_name):
+    tensor_list = [arg for arg in args if isinstance(arg, torch.Tensor)]
+
+    run_index = 0
+
+    # Some kernels don't have a path and hash stored;
+    # use the name alone to differentiate between those
+    if not kernel_path:
+        kernel_hash = kernel_name
+
+    # Save only the last N runs of each kernel to avoid bloating the folder
+    if kernel_hash in inductor_triton_config.kernel_dump_occurrence_map:
+        run_index = inductor_triton_config.kernel_dump_occurrence_map[kernel_hash] + 1
+
+        if run_index >= inductor_triton_config.max_kernel_dump_occurrences:
+            run_index = 0
+
+    inductor_triton_config.kernel_dump_occurrence_map[kernel_hash] = run_index
+
+    # Default path for kernels with no hash
+    if not kernel_path:
+        directory_path = "/tmp/torchinductor_root/unhashed_kernel_inputs"
+    else:
+        directory_path = os.path.dirname(kernel_path)
+    directory_path = f"{directory_path}/{kernel_name}_run_{run_index}"
+    os.makedirs(directory_path, exist_ok=True)
+
+    for tensor_index, tensor in enumerate(tensor_list):
+        torch.save(tensor, f"{directory_path}/tensor_{tensor_index}.pt")
+
+
 def check_autotune_cache(
     configs: list[Config], filename: Optional[str], inductor_meta: dict[str, Any]
 ) -> tuple[list[Config], Optional[AutotuneCache], dict[str, Any]]:
@@ -367,4 +399,14 @@ def __init__(
         self.dump_launch_params = (
             os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", "0") == "1"
         )
+        self.dump_launch_tensors = (
+            os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_TENSORS", "0") == "1"
+        )
+        # Empty entries are dropped so that an unset or empty variable
+        # dumps tensors for every kernel
+        self.kernels_to_dump = [
+            k
+            for k in os.environ.get("TORCHINDUCTOR_KERNELS_TO_DUMP", "").split(",")
+            if k
+        ]
         self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1"
@@ -1291,6 +1333,15 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
             new_args, grid = self._interpret_args_grid(args, launcher.config)
             _dump_launch_params(new_args, kwargs, launcher, self.fn.__name__, grid)
 
+        if self.dump_launch_tensors:
+            # Dump only kernels from the allowlist, if one was provided
+            if not self.kernels_to_dump or any(
+                kernel_name in self.fn.__name__ for kernel_name in self.kernels_to_dump
+            ):
+                _dump_launch_tensors(
+                    args, self.filename, self.kernel_hash, self.fn.__name__
+                )
+
         # it is faster than entering and exiting a context manager, even if the context
         # manager is a nullcontext.
         if autograd_profiler._is_profiler_enabled:
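
For context, a minimal usage sketch of the feature this patch adds (not part of the patch itself). It assumes a CUDA build of PyTorch with the patch applied; the kernel-name fragment "triton_poi_fused_add" and the glob over /tmp/torchinductor_* are illustrative assumptions, since actual kernel names and cache locations vary by run.

    # Hedged usage sketch; the kernel-name fragment and dump location below
    # are illustrative assumptions, not values taken from the patch.
    import os

    # Both variables are read when the autotuner is constructed, so set them
    # before the first compiled call; an unset allowlist dumps every kernel.
    os.environ["TORCHINDUCTOR_DUMP_LAUNCH_TENSORS"] = "1"
    os.environ["TORCHINDUCTOR_KERNELS_TO_DUMP"] = "triton_poi_fused_add"

    import glob
    import torch

    @torch.compile
    def add(x, y):
        return x + y

    add(torch.randn(1024, device="cuda"), torch.randn(1024, device="cuda"))

    # Each matching launch saves its tensor arguments as
    # <kernel_dir>/<kernel_name>_run_<N>/tensor_<i>.pt, keeping the last 3 runs
    # (max_kernel_dump_occurrences) per kernel hash.
    for path in sorted(
        glob.glob("/tmp/torchinductor_*/**/*_run_*/tensor_*.pt", recursive=True)
    ):
        print(path, torch.load(path).shape)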