Skip to content

Commit 0ecd727

Browse files
NikhilAPatelfacebook-github-bot
authored andcommitted
[Inductor][Tritonparse] Get Inductor kernel params (pytorch#161953)
Summary: X-link: meta-pytorch/tritonparse#89 Pull Request resolved: pytorch#161953 Save the config args that Inductor burns into `inductor_metadata` so we can optionally pass them to any Jit Hooks that are set. This allows us to pass them to Tritonparse. Reviewed By: davidberard98, FindHao Differential Revision: D80994791
1 parent a99d8d3 commit 0ecd727

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -772,14 +772,25 @@ def _precompile_config(self, cfg: Config) -> CompileResult[_KernelType]:
772772
and getattr(knobs.runtime, "jit_post_compile_hook", None)
773773
):
774774
try:
775-
knobs.runtime.jit_post_compile_hook(
775+
hook = knobs.runtime.jit_post_compile_hook
776+
777+
# base args everyone should get
778+
call_kwargs = dict(
776779
key=getattr(self.fn, "cache_key", self.kernel_hash or str(self.fn)),
777780
repr=getattr(self.fn, "src", None),
778781
fn=self.fn,
779782
compile=binary,
780783
is_manual_warmup=False,
781784
already_compiled=True,
782785
)
786+
787+
# only add inductor_args if the hook takes it
788+
sig = inspect.signature(hook)
789+
params = sig.parameters
790+
if "inductor_args" in params:
791+
call_kwargs["inductor_args"] = self.inductor_meta["config_args"]
792+
793+
hook(**call_kwargs)
783794
except Exception:
784795
log.exception("jit_post_compile_hook failed")
785796

torch/_inductor/select_algorithm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,8 @@ def jit_lines(self):
613613
flops = self.estimate_flops()
614614
inductor_meta["kernel_flop"] = flops
615615

616+
inductor_meta["config_args"] = self.meta
617+
616618
template_args = f"""
617619
num_stages={self.num_stages},
618620
num_warps={self.num_warps},

0 commit comments

Comments
 (0)