1+ import concurrent .futures
12import functools
23import inspect
34import itertools
1617def build (premake , constexpr_param_grid , caller , op_name , output_dir ):
1718 headers = []
1819 all_param_names = []
20+ combinations = []
1921 launches = []
2022
21- for combination in _generate_param_value_combinations ( constexpr_param_grid ) :
22- arrangement , application , tensors = premake ( ** combination )
23+ with concurrent . futures . ProcessPoolExecutor () as executor :
24+ futures = []
2325
24- for param_name , param_value in combination .items ():
25- if isinstance (param_value , str ):
26- combination [param_name ] = (
27- f"INFINI_DTYPE_{ combination [param_name ].replace ('fp' , 'F' ).upper ()} "
28- )
26+ for combination in tuple (
27+ _generate_param_value_combinations (constexpr_param_grid )
28+ ):
29+ future = executor .submit (
30+ _make , premake , combination , caller , op_name , output_dir
31+ )
2932
30- combination = { f" { name } _" : value for name , value in combination . items ()}
33+ futures . append ( future )
3134
32- kernel_name = f"{ op_name } _{ _generate_suffix (combination .values ())} "
35+ for future in concurrent .futures .as_completed (futures ):
36+ header , param_names , combination , launch = future .result ()
3337
34- ninetoothed .make (
35- arrangement ,
36- application ,
37- tensors ,
38- caller = caller ,
39- kernel_name = kernel_name ,
40- output_dir = output_dir ,
41- )
42-
43- header = output_dir / f"{ kernel_name } .h"
44- param_names = ("stream" ,) + tuple (
45- inspect .signature (application ).parameters .keys ()
46- )
47- launch = f""" if ({ _generate_condition (combination )} )
48- return launch_{ kernel_name } ({ ", " .join (param_names )} );"""
49-
50- headers .append (header )
51- all_param_names .append (param_names )
52- launches .append (launch )
38+ headers .append (header )
39+ all_param_names .append (param_names )
40+ combinations .append (combination )
41+ launches .append (launch )
5342
5443 includes = "\n " .join (f'#include "{ header } "' for header in headers )
5544
@@ -64,7 +53,7 @@ def build(premake, constexpr_param_grid, caller, op_name, output_dir):
6453 "NineToothedStream" ,
6554 ] + ["NineToothedTensor" for _ in range (len (param_names ) - 1 )]
6655
67- for param_name in combination :
56+ for param_name in functools . reduce ( lambda x , y : x | y , combinations , {}) :
6857 param_names .append (param_name )
6958 param_types .append ("int" )
7059
@@ -97,6 +86,36 @@ def build(premake, constexpr_param_grid, caller, op_name, output_dir):
9786 (BUILD_DIRECTORY_PATH / header_file_name ).write_text (header_content )
9887
9988
89+ def _make (premake , combination , caller , op_name , output_dir ):
90+ arrangement , application , tensors = premake (** combination )
91+
92+ for param_name , param_value in combination .items ():
93+ if isinstance (param_value , str ):
94+ combination [param_name ] = (
95+ f"INFINI_DTYPE_{ combination [param_name ].replace ('fp' , 'F' ).upper ()} "
96+ )
97+
98+ combination = {f"{ name } _" : value for name , value in combination .items ()}
99+
100+ kernel_name = f"{ op_name } _{ _generate_suffix (combination .values ())} "
101+
102+ ninetoothed .make (
103+ arrangement ,
104+ application ,
105+ tensors ,
106+ caller = caller ,
107+ kernel_name = kernel_name ,
108+ output_dir = output_dir ,
109+ )
110+
111+ header = output_dir / f"{ kernel_name } .h"
112+ param_names = ("stream" ,) + tuple (inspect .signature (application ).parameters .keys ())
113+ launch = f""" if ({ _generate_condition (combination )} )
114+ return launch_{ kernel_name } ({ ", " .join (param_names )} );"""
115+
116+ return header , param_names , combination , launch
117+
118+
100119def _generate_condition (combination ):
101120 return " && " .join (f"{ param } == { value } " for param , value in combination .items ())
102121
0 commit comments