Skip to content

Commit dd0efa6

Browse files
authored
[run_utils] Apply run options to config (#449)
1 parent 6bc8ac2 commit dd0efa6

File tree

4 files changed

+40
-16
lines changed

4 files changed

+40
-16
lines changed

benchmarks/dump_ir/run.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,27 @@
33
For autotuned operators, we save the IRs of the best kernels.
44
"""
55

6+
import argparse
67
import os
78
from pathlib import Path
8-
import argparse
9-
from tritonbench.utils.env_utils import is_fbcode
10-
from tritonbench.utils.run_utils import run_in_task, run_one_operator
11-
from tritonbench.operators import list_custom_triton_operators
129

13-
from typing import List, Dict
10+
from typing import Dict, List
11+
1412
from libfb.py import parutil
13+
from tritonbench.operators import list_custom_triton_operators
14+
from tritonbench.utils.env_utils import is_fbcode
15+
from tritonbench.utils.run_utils import run_in_task, run_one_operator
1516

16-
METADATA_DIR = parutil.get_file_path("tritonbench/metadata") if is_fbcode() \
17+
METADATA_DIR = (
18+
parutil.get_file_path("tritonbench/metadata")
19+
if is_fbcode()
1720
else Path(__file__).parent.parent.parent.joinpath("tritonbench/metadata")
21+
)
1822

1923
OSS_CUSTOM_TRITON_YAML = os.path.join(METADATA_DIR, "oss_triton_operators.yaml")
20-
INTERNAL_CUSTOM_TRITON_YAML = os.path.join(METADATA_DIR, "fb/internal_triton_operators.yaml")
24+
INTERNAL_CUSTOM_TRITON_YAML = os.path.join(
25+
METADATA_DIR, "fb/internal_triton_operators.yaml"
26+
)
2127

2228

2329
def get_parser():
@@ -29,25 +35,38 @@ def get_parser():
2935
help="Output directory to save the IRs",
3036
)
3137
parser.add_argument(
32-
"--run-in-task",
33-
action="store_true",
34-
help="indicate running in task."
38+
"--run-in-task", action="store_true", help="indicate running in task."
3539
)
3640
return parser
3741

3842

3943
def run_operator(op: str, subop: List[str], output_dir: str):
4044
"""Run a Tritonbench operator and save its IR to the specified directory"""
41-
opbench_args = ["--run-in-task", "--op", op, "--only", ",".join(subop), "--dump-ir", output_dir]
45+
opbench_args = [
46+
"--run-in-task",
47+
"--op",
48+
op,
49+
"--only",
50+
",".join(subop),
51+
"--dump-ir",
52+
output_dir,
53+
]
4254
run_in_task(op, opbench_args)
4355

56+
4457
if __name__ == "__main__":
4558
parser = get_parser()
4659
args, extra_args = parser.parse_known_args()
4760
if args.run_in_task:
4861
run_one_operator(extra_args, with_bwd=True)
4962
exit(0)
50-
custom_triton_op_yamls = [OSS_CUSTOM_TRITON_YAML, INTERNAL_CUSTOM_TRITON_YAML] if is_fbcode() else [OSS_CUSTOM_TRITON_YAML]
51-
operators: Dict[str, List[str]] = list_custom_triton_operators(custom_triton_op_yamls)
63+
custom_triton_op_yamls = (
64+
[OSS_CUSTOM_TRITON_YAML, INTERNAL_CUSTOM_TRITON_YAML]
65+
if is_fbcode()
66+
else [OSS_CUSTOM_TRITON_YAML]
67+
)
68+
operators: Dict[str, List[str]] = list_custom_triton_operators(
69+
custom_triton_op_yamls
70+
)
5271
[run_operator(op, operators[op].keys(), args.output_dir) for op in operators]
5372
print(f"[tritonbench][dump_ir] Result saved to {args.output_dir}")

run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def run(args: List[str] = []):
102102
if args == []:
103103
args = sys.argv[1:]
104104
if config := os.environ.get("TRITONBENCH_RUN_CONFIG", None):
105-
run_config(config)
105+
run_config(config, args)
106106
return
107107

108108
# Log the tool usage

tritonbench/operators/layer_norm/operator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def parse_op_args(args: List[str]):
4343

4444
try:
4545
from quack.quack_layernorm import layernorm as quack_layernorm
46+
4647
HAS_QUACK_KERNEL = True
4748
except ModuleNotFoundError:
4849
HAS_QUACK_KERNEL = False

tritonbench/utils/run_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,15 +98,19 @@ def get_github_env() -> Dict[str, str]:
9898
return out
9999

100100

101-
def run_config(config_file: str):
101+
def run_config(config_file: str, args: List[str]):
102102
assert Path(config_file).exists(), f"Config file {config_file} must exist."
103103
with open(config_file, "r") as fp:
104104
config = yaml.safe_load(fp)
105105
for benchmark_name in config:
106106
benchmark_config = config[benchmark_name]
107107
op_name = benchmark_config["op"]
108-
op_args = benchmark_config["args"].split(" ")
108+
op_args = benchmark_config["args"].split(" ") + args
109109
env_string = benchmark_config.get("envs", None)
110+
disabled = benchmark_config.get("disabled", False)
111+
if disabled:
112+
logger.info(f"Skipping disabled benchmark {benchmark_name}.")
113+
continue
110114
extra_envs = {}
111115
if env_string:
112116
for env_part in env_string.split(" "):

0 commit comments

Comments
 (0)