3
3
For autotuned operators, we save the IRs of the best kernels.
4
4
"""
5
5
6
+ import argparse
6
7
import os
7
8
from pathlib import Path
8
- import argparse
9
- from tritonbench .utils .env_utils import is_fbcode
10
- from tritonbench .utils .run_utils import run_in_task , run_one_operator
11
- from tritonbench .operators import list_custom_triton_operators
12
9
13
- from typing import List , Dict
10
+ from typing import Dict , List
11
+
14
12
from libfb .py import parutil
13
+ from tritonbench .operators import list_custom_triton_operators
14
+ from tritonbench .utils .env_utils import is_fbcode
15
+ from tritonbench .utils .run_utils import run_in_task , run_one_operator
15
16
16
- METADATA_DIR = parutil .get_file_path ("tritonbench/metadata" ) if is_fbcode () \
17
+ METADATA_DIR = (
18
+ parutil .get_file_path ("tritonbench/metadata" )
19
+ if is_fbcode ()
17
20
else Path (__file__ ).parent .parent .parent .joinpath ("tritonbench/metadata" )
21
+ )
18
22
19
23
OSS_CUSTOM_TRITON_YAML = os .path .join (METADATA_DIR , "oss_triton_operators.yaml" )
20
- INTERNAL_CUSTOM_TRITON_YAML = os .path .join (METADATA_DIR , "fb/internal_triton_operators.yaml" )
24
+ INTERNAL_CUSTOM_TRITON_YAML = os .path .join (
25
+ METADATA_DIR , "fb/internal_triton_operators.yaml"
26
+ )
21
27
22
28
23
29
def get_parser ():
@@ -29,25 +35,38 @@ def get_parser():
29
35
help = "Output directory to save the IRs" ,
30
36
)
31
37
parser .add_argument (
32
- "--run-in-task" ,
33
- action = "store_true" ,
34
- help = "indicate running in task."
38
+ "--run-in-task" , action = "store_true" , help = "indicate running in task."
35
39
)
36
40
return parser
37
41
38
42
39
43
def run_operator (op : str , subop : List [str ], output_dir : str ):
40
44
"""Run a Tritonbench operator and save its IR to the specified directory"""
41
- opbench_args = ["--run-in-task" , "--op" , op , "--only" , "," .join (subop ), "--dump-ir" , output_dir ]
45
+ opbench_args = [
46
+ "--run-in-task" ,
47
+ "--op" ,
48
+ op ,
49
+ "--only" ,
50
+ "," .join (subop ),
51
+ "--dump-ir" ,
52
+ output_dir ,
53
+ ]
42
54
run_in_task (op , opbench_args )
43
55
56
+
44
57
if __name__ == "__main__" :
45
58
parser = get_parser ()
46
59
args , extra_args = parser .parse_known_args ()
47
60
if args .run_in_task :
48
61
run_one_operator (extra_args , with_bwd = True )
49
62
exit (0 )
50
- custom_triton_op_yamls = [OSS_CUSTOM_TRITON_YAML , INTERNAL_CUSTOM_TRITON_YAML ] if is_fbcode () else [OSS_CUSTOM_TRITON_YAML ]
51
- operators : Dict [str , List [str ]] = list_custom_triton_operators (custom_triton_op_yamls )
63
+ custom_triton_op_yamls = (
64
+ [OSS_CUSTOM_TRITON_YAML , INTERNAL_CUSTOM_TRITON_YAML ]
65
+ if is_fbcode ()
66
+ else [OSS_CUSTOM_TRITON_YAML ]
67
+ )
68
+ operators : Dict [str , List [str ]] = list_custom_triton_operators (
69
+ custom_triton_op_yamls
70
+ )
52
71
[run_operator (op , operators [op ].keys (), args .output_dir ) for op in operators ]
53
72
print (f"[tritonbench][dump_ir] Result saved to { args .output_dir } " )
0 commit comments