@@ -27,6 +27,15 @@ def get_reference_output_path(reference_dir, model_path):
2727 return os .path .join (reference_dir , f"{ model_name } .pth" )
2828
2929
30+ def register_op_lib (op_lib ):
31+ if op_lib == "flaggems" :
32+ import flag_gems
33+
34+ flag_gems .enable ()
35+ else :
36+ pass
37+
38+
3039def test_single_model (args ):
3140 ref_log = get_reference_log_path (args .reference_dir , args .model_path )
3241 ref_dump = get_reference_output_path (args .reference_dir , args .model_path )
@@ -51,6 +60,10 @@ def test_single_model(args):
5160 test_compiler .get_compile_framework_version (args ),
5261 )
5362
63+ test_compiler_util .print_with_log_prompt (
64+ "[Config] op_lib:" , args .op_lib , args .log_prompt
65+ )
66+
5467 success = False
5568 time_stats = {}
5669 try :
@@ -99,6 +112,7 @@ def test_multi_models(args):
99112 f"--model-path { model_path } " ,
100113 f"--compiler { args .compiler } " ,
101114 f"--device { args .device } " ,
115+ f"--op-lib { args .op_lib } " ,
102116 f"--warmup { args .warmup } " ,
103117 f"--trials { args .trials } " ,
104118 f"--log-prompt { args .log_prompt } " ,
@@ -136,6 +150,7 @@ def main(args):
136150 ref_dump_dir .mkdir (parents = True , exist_ok = True )
137151
138152 if path_utils .is_single_model_dir (args .model_path ):
153+ register_op_lib (args .op_lib )
139154 test_single_model (args )
140155 else :
141156 test_multi_models (args )
@@ -163,6 +178,13 @@ def main(args):
163178 default = "cuda" ,
164179 help = "Device for testing the compiler (e.g., 'cpu' or 'cuda')" ,
165180 )
181+ parser .add_argument (
182+ "--op-lib" ,
183+ type = str ,
184+ required = False ,
185+ default = "default" ,
186+ help = "Customized operator library (eg. default, flaggems)" ,
187+ )
166188 parser .add_argument (
167189 "--warmup" , type = int , required = False , default = 5 , help = "Number of warmup steps"
168190 )
0 commit comments