34 | 34 | from torch._environment import is_fbcode |
35 | 35 | from torch._prims_common import compute_required_storage_length |
36 | 36 | from torch.utils._ordered_set import OrderedSet |
| 37 | +from torch._inductor.config import triton as inductor_triton_config |
37 | 38 |
38 | 39 | from ..triton_bundler import TritonBundler |
39 | 40 | from ..utils import prefix_is_reduction, triton_version_uses_attrs_dict |
@@ -223,6 +224,39 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid): |
223 | 224 |         f.write(f"{kernel_name} | {args_str} | {grid!r}\n") |
224 | 225 |
225 | 226 |
| 227 | +def _dump_launch_tensors(args, kernel_path, kernel_hash, kernel_name): |
| 228 | +    tensor_list = [arg for arg in args if isinstance(arg, torch.Tensor)] |
| 229 | + |
| 230 | +    run_index = 0 |
| 231 | + |
| 232 | +    # Some kernels don't have a path and hash stored; |
| 233 | +    # fall back to the name to differentiate between them |
| 234 | +    if not kernel_path: |
| 235 | +        kernel_hash = kernel_name |
| 236 | + |
| 237 | +    # Keep only the last N runs of each kernel to avoid bloating the folder |
| 238 | +    if kernel_hash in inductor_triton_config.kernel_dump_occurency_map: |
| 239 | +        run_index = inductor_triton_config.kernel_dump_occurency_map[kernel_hash] + 1 |
| 240 | + |
| 241 | +        if run_index >= inductor_triton_config.max_kernel_dump_occurencies: |
| 242 | +            run_index = 0 |
| 243 | + |
| 244 | +    inductor_triton_config.kernel_dump_occurency_map[kernel_hash] = run_index |
| 245 | + |
| 246 | +    # Default path for kernels with no hash |
| 247 | +    if not kernel_path: |
| 248 | +        directory_path = "/tmp/torchinductor_root/unhashed_kernel_inputs" |
| 249 | +    else: |
| 250 | +        directory_path = os.path.dirname(kernel_path) |
| 251 | +    directory_path = f"{directory_path}/{kernel_name}_run_{run_index}" |
| 252 | +    os.makedirs(directory_path, exist_ok=True) |
| 253 | + |
| 254 | +    # Save each tensor argument as tensor_<index>.pt in the run directory |
| 255 | +    for tensor_index, tensor in enumerate(tensor_list): |
| 256 | +        torch.save(tensor, f"{directory_path}/tensor_{tensor_index}.pt") |
| 257 | + |
| 258 | + |
| 259 | + |
226 | 260 | def check_autotune_cache( |
227 | 261 |     configs: list[Config], filename: Optional[str], inductor_meta: dict[str, Any] |
228 | 262 | ) -> tuple[list[Config], Optional[AutotuneCache], dict[str, Any]]: |
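For reference, the dumped files can be loaded back with a short script like the one below (a minimal sketch; the run-directory name is hypothetical and follows the `{kernel_name}_run_{run_index}` layout created above, or the unhashed fallback directory when no kernel path is available):

```python
import glob
import os

import torch

# Hypothetical run directory; layout is <dirname(kernel_path)>/<kernel_name>_run_<run_index>
run_dir = "/tmp/torchinductor_root/unhashed_kernel_inputs/triton_poi_fused_add_0_run_0"

# Sort numerically: a lexicographic sort would put tensor_10.pt before tensor_2.pt
paths = sorted(
    glob.glob(os.path.join(run_dir, "tensor_*.pt")),
    key=lambda p: int(os.path.basename(p).removesuffix(".pt").rsplit("_", 1)[1]),
)
tensors = [torch.load(p, map_location="cpu") for p in paths]
for index, tensor in enumerate(tensors):
    print(index, tuple(tensor.shape), tensor.dtype)
```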
@@ -367,6 +401,10 @@ def __init__( |
367 | 401 |         self.dump_launch_params = ( |
368 | 402 |             os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", "0") == "1" |
369 | 403 |         ) |
| 404 | +        self.dump_launch_tensors = ( |
| 405 | +            os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_TENSORS", "0") == "1" |
| 406 | +        ) |
| 407 | +        self.kernels_to_dump = [k for k in os.environ.get("TORCHINDUCTOR_KERNELS_TO_DUMP", "").split(",") if k] |
370 | 408 |
371 | 409 |         self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1" |
372 | 410 |
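A quick way to exercise these flags from a driver script (a sketch; the kernel-name fragments are hypothetical):

```python
import os

# Both variables are read once in the autotuner's __init__ shown above, so they
# must be set before the model is compiled. Matching is by substring, and an
# empty/unset list dumps every kernel.
os.environ["TORCHINDUCTOR_DUMP_LAUNCH_TENSORS"] = "1"
os.environ["TORCHINDUCTOR_KERNELS_TO_DUMP"] = "fused_add,fused_mul"
```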
@@ -1291,6 +1329,11 @@ def alloc_fn(size: int, align: int, stream: Optional[int]): |
1291 | 1329 |             new_args, grid = self._interpret_args_grid(args, launcher.config) |
1292 | 1330 |             _dump_launch_params(new_args, kwargs, launcher, self.fn.__name__, grid) |
1293 | 1331 |
| 1332 | +        if self.dump_launch_tensors: |
| 1333 | +            # Dump only kernels whose name matches an entry in the list, if one was provided |
| 1334 | +            if not self.kernels_to_dump or any(k in self.fn.__name__ for k in self.kernels_to_dump): |
| 1335 | +                _dump_launch_tensors(args, self.filename, self.kernel_hash, self.fn.__name__) |
| 1336 | + |
1294 | 1337 |         # it is faster than entering and exiting a context manager, even if the context |
1295 | 1338 |         # manager is a nullcontext. |
1296 | 1339 |         if autograd_profiler._is_profiler_enabled: |
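Note that the filter above matches by substring against `self.fn.__name__`, so a single entry can select a whole family of fused kernels; a small illustration (names are hypothetical):

```python
# Hypothetical names illustrating the substring check in the launch hook above
kernels_to_dump = ["fused_add"]          # parsed from TORCHINDUCTOR_KERNELS_TO_DUMP
fn_name = "triton_poi_fused_add_relu_3"  # a compiled Triton kernel's function name

should_dump = not kernels_to_dump or any(k in fn_name for k in kernels_to_dump)
assert should_dump  # "fused_add" is a substring, so this kernel's tensors are dumped
```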