33import importlib .util
44import sys
55from argparse import ArgumentParser
6+ from dataclasses import dataclass
67from pathlib import Path
78from typing import List
89
910import triton
1011import triton .backends
11- from triton .backends .nvidia .driver import ty_to_cpp
12+
13+
14+ @dataclass
15+ class CompileArgs :
16+ '''
17+ A class to contain arguments from command-line parser.
18+ '''
19+ path : str = ''
20+ kernel_name : str = ''
21+ signature : str = ''
22+ grid : str = ''
23+ target : str | None = None
24+ num_warps : int = 1
25+ num_stages : int = 3
26+ out_name : str | None = None
27+ out_path : Path | None = None
28+
1229
1330desc = """
1431Triton ahead-of-time compiler:
3653used to run this `compile.py` script
3754"""
3855
39- if __name__ == "__main__" :
4056
57+ def main ():
4158 # command-line arguments
4259 parser = ArgumentParser (description = desc )
4360 parser .add_argument ("path" ,
4461 help = "Path to Python source containing desired kernel in its scope. File will be executed." )
4562 parser .add_argument ("--kernel-name" , "-n" , type = str , default = "" , help = "Name of the kernel to compile" ,
4663 required = True )
64+ parser .add_argument (
65+ "--target" , "-t" , type = str , default = None ,
66+ help = "The target to compile towards, in format of '<backend>:<arch>:<warp-size>'; "
67+ "e.g., 'cuda:80:32', 'hip:gfx942:64'. Default to None, which means using current machine's GPU target" )
4768 parser .add_argument ("--num-warps" , "-w" , type = int , default = 1 , help = "Number of warps to launch the kernel" )
4869 parser .add_argument ("--num-stages" , "-ns" , type = int , default = 3 ,
4970 help = "Number of stages (meta-parameter of the kernel)" )
5071 parser .add_argument ("--out-name" , "-on" , type = str , default = None , help = "Out name for the compiled kernel" )
5172 parser .add_argument ("--out-path" , "-o" , type = Path , default = None , help = "Out filename" )
5273 parser .add_argument ("--signature" , "-s" , type = str , help = "Signature of the kernel" , required = True )
5374 parser .add_argument ("--grid" , "-g" , type = str , help = "Launch grid of the kernel" , required = True )
54- args = parser .parse_args ()
75+ cli_args = parser .parse_args ()
76+ args = CompileArgs (** vars (cli_args )) # A sanity check to ensure class CompileArgs is updated as well.
77+ compile_kernel (args )
5578
79+
80+ def compile_kernel (args : CompileArgs ):
5681 out_name = args .out_name if args .out_name else args .kernel_name
5782 out_path = args .out_path if args .out_path else Path (out_name )
5883
@@ -108,9 +133,15 @@ def constexpr(s):
108133 assert h in [1 , 16 ], f"Only 1 and 16 are valid hints, got { h } "
109134 attrs = {k : [["tt.divisibility" , 16 ]] for k , v in hints .items () if v == 16 }
110135 src = triton .compiler .ASTSource (fn = kernel , constexprs = constants , signature = signature , attrs = attrs )
111- opts = {"num_warps" : args .num_warps , "num_stages" : args .num_stages }
112- ccinfo = triton .compile (src , options = opts )
113- if ccinfo .metadata .global_scratch_size > 0 :
136+
137+ target = triton .backends .compiler .GPUTarget (* args .target .split (":" )) \
138+ if args .target else triton .runtime .driver .active .get_current_target ()
139+ backend = triton .compiler .make_backend (target )
140+ kwargs = {"num_warps" : args .num_warps , "num_stages" : args .num_stages }
141+ options = backend .parse_options (kwargs )
142+ ccinfo = triton .compile (src , target = target , options = options .__dict__ )
143+
144+ if getattr (ccinfo .metadata , "global_scratch_size" , 0 ) > 0 :
114145 raise RuntimeError ("AOT compiling kernels with global scratch requirements is not yet implemented" )
115146
116147 arg_names = []
@@ -136,8 +167,12 @@ def constexpr(s):
136167 if hints .get ((i , ), None ) == 16 :
137168 suffix += 'd'
138169 func_name = '_' .join ([out_name , sig_hash , suffix ])
139- asm = ccinfo .asm ["cubin" ] # store binary data once
170+ asm = ccinfo .asm [backend .binary_ext ] # store binary data once
171+
140172 hex_ = str (binascii .hexlify (asm ))[2 :- 1 ]
173+
174+ ty_to_cpp = triton .runtime .driver .active .map_python_to_cpp_type
175+
141176 params = {
142177 "kernel_name" : func_name ,
143178 "triton_kernel_name" : args .kernel_name ,
@@ -156,7 +191,18 @@ def constexpr(s):
156191 "gridZ" : grid [2 ],
157192 "_placeholder" : "" ,
158193 }
159- for ext in ['h' , 'c' ]:
160- template_path = Path (__file__ ).parent / "extra" / "cuda" / f"compile.{ ext } "
161- with out_path .with_suffix (f".{ sig_hash } _{ suffix } .{ ext } " ).open ("w" ) as fp :
162- fp .write (Path (template_path ).read_text ().format (** params ))
194+ output_files = []
195+ backend_name = target .backend
196+ template_dir = Path (__file__ ).parent / "extra" / backend_name
197+ for template_path in template_dir .glob ('compile.*' ):
198+ ext = template_path .suffix
199+ output_file = out_path .with_suffix (f".{ sig_hash } _{ suffix } { ext } " )
200+ with output_file .open ("w" ) as fp :
201+ fp .write (template_path .read_text ().format (** params ))
202+ output_files .append (output_file )
203+
204+ return func_name , output_files
205+
206+
207+ if __name__ == "__main__" :
208+ main ()
0 commit comments