|
23 | 23 | import torch |
24 | 24 | from torch._prims_common import compute_required_storage_length |
25 | 25 | from torch.utils._ordered_set import OrderedSet |
| 26 | +from torch._inductor.config import triton as inductor_triton_config |
26 | 27 |
|
27 | 28 | from ..triton_bundler import TritonBundler |
28 | 29 | from ..utils import prefix_is_reduction, triton_version_uses_attrs_dict |
@@ -164,6 +165,39 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid): |
164 | 165 |         f.write(f"{kernel_name} | {args_str} | {grid!r}\n") |
165 | 166 |
|
166 | 167 |
|
| 168 | +def _dump_launch_tensors(args, kernel_path, kernel_hash, kernel_name): |
| 169 | +    tensor_list = [arg for arg in args if isinstance(arg, torch.Tensor)] |
| 170 | + |
| 171 | +    run_index = 0 |
| 172 | + |
| 173 | +    # Some kernels don't have a path and hash stored; |
| 174 | +    # fall back to the kernel name to differentiate between them |
| 175 | +    if not kernel_path: |
| 176 | +        kernel_hash = kernel_name |
| 177 | + |
| 178 | +    # Keep only the last N runs of each kernel to avoid bloating the folder |
| 179 | +    if kernel_hash in inductor_triton_config.kernel_dump_occurency_map: |
| 180 | +        run_index = inductor_triton_config.kernel_dump_occurency_map[kernel_hash] + 1 |
| 181 | + |
| 182 | +        if run_index >= inductor_triton_config.max_kernel_dump_occurencies: |
| 183 | +            run_index = 0 |
| 184 | + |
| 185 | +    inductor_triton_config.kernel_dump_occurency_map[kernel_hash] = run_index |
| 186 | + |
| 187 | +    # Default path for kernels with no hash |
| 188 | +    if not kernel_path: |
| 189 | +        directory_path = "/tmp/torchinductor_root/unhashed_kernel_inputs" |
| 190 | +    else: |
| 191 | +        directory_path = os.path.dirname(kernel_path) |
| 192 | +    directory_path = f"{directory_path}/{kernel_name}_run_{run_index}" |
| 193 | +    os.makedirs(directory_path, exist_ok=True) |
| 194 | + |
| 195 | +    for tensor_index, tensor in enumerate(tensor_list): |
| 196 | +        torch.save(tensor, f"{directory_path}/tensor_{tensor_index}.pt") |
| 197 | + |
| 198 | + |
167 | 201 | class CachingAutotuner(KernelInterface): |
168 | 202 | """ |
169 | 203 | Simplified version of Triton autotuner that has no invalidation |
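To inspect a dump offline, the saved tensors can be loaded back in argument order. The sketch below is not part of this PR; it only assumes the layout `_dump_launch_tensors` produces above (`tensor_{i}.pt` files inside a `{kernel_name}_run_{index}` directory), and the run directory in the example is hypothetical.

```python
import os
import torch

def load_dumped_tensors(run_dir: str) -> list[torch.Tensor]:
    """Load tensors saved by _dump_launch_tensors, preserving argument order."""
    tensors = []
    index = 0
    while True:
        path = os.path.join(run_dir, f"tensor_{index}.pt")
        if not os.path.exists(path):
            break
        tensors.append(torch.load(path))
        index += 1
    return tensors

# Hypothetical dump directory for an unhashed kernel named "triton_poi_fused_add_0"
tensors = load_dumped_tensors(
    "/tmp/torchinductor_root/unhashed_kernel_inputs/triton_poi_fused_add_0_run_0"
)
print([tuple(t.shape) for t in tensors])
```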
@@ -255,6 +289,10 @@ def __init__( |
255 | 289 |         self.dump_launch_params = ( |
256 | 290 |             os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", "0") == "1" |
257 | 291 |         ) |
| 292 | +        self.dump_launch_tensors = ( |
| 293 | +            os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_TENSORS", "0") == "1" |
| 294 | +        ) |
| 295 | +        # Drop empty entries so an unset or empty env var means "dump every kernel" |
| 296 | +        kernels_env = os.environ.get("TORCHINDUCTOR_KERNELS_TO_DUMP", "") |
| 297 | +        self.kernels_to_dump = [name for name in kernels_env.split(",") if name] |
258 | 296 |
|
259 | 297 |         self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1" |
260 | 298 |
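Both flags are read in `__init__`, so they must be set before Inductor compiles any kernels. A minimal sketch of driving a dump from a script, assuming a CUDA device; the filter matches by substring against `fn.__name__`, so a fragment such as `fused_add` is enough (the compiled function below is a stand-in workload, not from this PR):

```python
import os

# Must be set before torch.compile triggers any Triton compilation
os.environ["TORCHINDUCTOR_DUMP_LAUNCH_TENSORS"] = "1"
# Optional comma-separated substring filter; unset or empty dumps every kernel
os.environ["TORCHINDUCTOR_KERNELS_TO_DUMP"] = "fused_add"

import torch

@torch.compile
def f(x, y):
    return x + y * 2  # stand-in workload that lowers to a fused pointwise kernel

x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
f(x, y)
```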
|
@@ -925,6 +963,11 @@ def run( |
925 | 963 |             new_args, grid = self._interpret_args_grid(args, launcher.config) |
926 | 964 |             _dump_launch_params(new_args, kwargs, launcher, self.fn.__name__, grid) |
927 | 965 | |

| 966 | +        if self.dump_launch_tensors: |
| 967 | +            # Dump only kernels whose names match the filter, if one was provided |
| 968 | +            if not self.kernels_to_dump or any( |
| 969 | +                kernel_name in self.fn.__name__ for kernel_name in self.kernels_to_dump |
| 970 | +            ): |
| 971 | +                _dump_launch_tensors(args, self.filename, self.kernel_hash, self.fn.__name__) |
| 972 | + |
928 | 973 |         # it is faster than entering and exiting a context manager, even if the context |
929 | 974 |         # manager is a nullcontext. |
930 | 975 |         if autograd_profiler._is_profiler_enabled: |
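The rotation in `_dump_launch_tensors` bounds disk usage per kernel: the run index advances by one on each launch and wraps back to zero once it reaches the configured maximum, so older dumps are overwritten in place. An equivalent toy version of that bookkeeping (the names `occurrence_map` and `MAX_RUNS` are illustrative stand-ins, not the actual config attributes):

```python
MAX_RUNS = 3         # stand-in for max_kernel_dump_occurencies
occurrence_map = {}  # stand-in for kernel_dump_occurency_map

def next_run_index(kernel_hash: str) -> int:
    # Same wrap-around as above: previous index + 1, modulo the maximum
    run_index = (occurrence_map.get(kernel_hash, -1) + 1) % MAX_RUNS
    occurrence_map[kernel_hash] = run_index
    return run_index

print([next_run_index("some_hash") for _ in range(5)])  # -> [0, 1, 2, 0, 1]
```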
|