Skip to content

Commit 8df3bb6

Browse files
NikhilAPatelfacebook-github-bot
authored andcommitted
Get Inductor kernel params (#89)
Summary: X-link: pytorch/pytorch#161953 Save the config args that Inductor burns into `inductor_metadata` so we can optionally pass them to any Jit Hooks that are set. This allows us to pass them to Tritonparse. Reviewed By: davidberard98, FindHao Differential Revision: D80994791
1 parent fa5db13 commit 8df3bb6

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

tritonparse/structured_logging.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from dataclasses import asdict, is_dataclass
1616
from datetime import date, datetime
1717
from enum import Enum
18+
from functools import partial
1819
from pathlib import Path
1920
from typing import Any, Callable, Dict, List, Optional, Union
2021

@@ -819,10 +820,18 @@ def extract_arg_info(arg_dict):
819820
return extracted_args
820821

821822

822-
def add_launch_metadata(grid, metadata, arg_dict):
823+
def add_launch_metadata(grid, metadata, arg_dict, inductor_args=None):
823824
# Extract detailed argument information
824825
extracted_args = extract_arg_info(arg_dict)
825-
return {"launch_metadata_tritonparse": (grid, metadata._asdict(), extracted_args)}
826+
extracted_inductor_args = extract_arg_info(inductor_args) if inductor_args else {}
827+
return {
828+
"launch_metadata_tritonparse": (
829+
grid,
830+
metadata._asdict(),
831+
extracted_args,
832+
extracted_inductor_args,
833+
)
834+
}
826835

827836

828837
class JITHookImpl(JITHook):
@@ -848,6 +857,7 @@ def __call__(
848857
compile,
849858
is_manual_warmup: bool,
850859
already_compiled: bool,
860+
inductor_args: Optional[Dict[str, Any]] = None,
851861
) -> Optional[bool]:
852862
"""
853863
Override or set the launch_metadata function for the JIT-compiled kernel.
@@ -882,7 +892,9 @@ def __call__(
882892
log.warning(
883893
f"fn {fn} launch_metadata is not None: {current_launch_metadata}. It will be overridden by tritonparse."
884894
)
885-
function.launch_metadata = add_launch_metadata
895+
function.launch_metadata = partial(
896+
add_launch_metadata, inductor_args=inductor_args
897+
)
886898
return True
887899

888900

@@ -946,6 +958,7 @@ def __call__(self, metadata):
946958
trace_data["extracted_args"] = launch_metadata_tritonparse[
947959
2
948960
] # Now contains detailed arg info
961+
trace_data["extracted_inductor_args"] = launch_metadata_tritonparse[3]
949962
trace_structured_triton("launch", metadata_fn=lambda: convert(trace_data))
950963

951964

0 commit comments

Comments
 (0)