@@ -109,9 +109,6 @@ def read_configs(
109109 postfixes = postfixes_tmp ,
110110 )
111111
112- if with_npbench :
113- fix_npbench_configs (config .benchmarks )
114-
115112 return config
116113
117114
@@ -232,6 +229,8 @@ def setup_init(config: Benchmark, modules: list[str]) -> None:
232229 init_module = None
233230 if config .module_name in modules :
234231 init_module = config .module_name
232+ elif config .short_name in modules :
233+ init_module = config .short_name
235234 elif config .module_name + "_initialize" in modules :
236235 init_module = config .module_name + "_initialize"
237236
@@ -251,6 +250,34 @@ def setup_init(config: Benchmark, modules: list[str]) -> None:
251250 )
252251
253252
253+ def discover_module_name_and_postfix (module : str , config : Config ):
254+ """Discover real module name and postfix for the implementation.
255+
256+ Args:
257+ module: Name of the root python module (either python file or top level
258+ folder for sycl).
259+ config: Module config.
260+
261+ Returns: (module_name, postfix).
262+ """
263+ postfix = ""
264+ module_name = ""
265+
266+ if module .endswith ("sycl_native_ext" ):
267+ module_name = (
268+ f"{ module } .{ config .module_name } _sycl._{ config .module_name } _sycl"
269+ )
270+ postfix = "sycl"
271+ else :
272+ module_name = module
273+ if module .startswith (config .module_name ):
274+ postfix = module [len (config .module_name ) + 1 :]
275+ elif module .startswith (config .short_name ):
276+ postfix = module [len (config .short_name ) + 1 :]
277+
278+ return module_name , postfix
279+
280+
254281def read_benchmark_implementations (
255282 config : Benchmark ,
256283 known_implementations : list [Implementation ],
@@ -289,17 +316,7 @@ def read_benchmark_implementations(
289316 setup_init (config , modules )
290317
291318 for module in modules :
292- postfix = ""
293- module_name = ""
294-
295- if module .endswith ("sycl_native_ext" ):
296- module_name = (
297- f"{ module } .{ config .module_name } _sycl._{ config .module_name } _sycl"
298- )
299- postfix = "sycl"
300- else :
301- module_name = module
302- postfix = module [len (config .module_name ) + 1 :]
319+ module_name , postfix = discover_module_name_and_postfix (module , config )
303320
304321 if postfixes and postfix not in postfixes :
305322 continue
@@ -318,6 +335,8 @@ def read_benchmark_implementations(
318335 impl_mod = importlib .import_module (package_path )
319336
320337 for func in [
338+ module ,
339+ f"{ module } _{ postfix } " ,
321340 config .module_name ,
322341 f"{ config .module_name } _{ postfix } " ,
323342 "kernel" ,
@@ -349,72 +368,3 @@ def get_benchmark_index(configs: list[Benchmark], module_name: str) -> int:
349368 ),
350369 None ,
351370 )
352-
353-
354- def fix_npbench_configs (configs : list [Benchmark ]):
355- """Applies configuration fixes for some npbench benchmarks.
356-
357- Fixes required due to the difference in framework implementations.
358- """
359- index = get_benchmark_index (configs , "mandelbrot1" )
360- if index is not None :
361- configs [index ] = modify_args (
362- configs [index ], modifier = lambda s : s .lower ()
363- )
364-
365- index = get_benchmark_index (configs , "mandelbrot2" )
366- if index is not None :
367- configs [index ] = modify_args (
368- configs [index ],
369- modifier = lambda s : "itermax" if s == "maxiter" else s .lower (),
370- )
371-
372- index = get_benchmark_index (configs , "conv2d" )
373- if index is not None :
374- config = configs [index ]
375-
376- config .module_name = "conv2d_bias"
377- configs [index ] = config
378-
379- for impl in config .implementations :
380- impl .func_name = "conv2d_bias"
381-
382- index = get_benchmark_index (configs , "nbody" )
383- if index is not None :
384- configs [index ].output_args .append ("pos" )
385- configs [index ].output_args .append ("vel" )
386-
387- index = get_benchmark_index (configs , "scattering_self_energies" )
388- if index is not None :
389- configs [index ].output_args .append ("Sigma" )
390-
391- index = get_benchmark_index (configs , "correlation" )
392- if index is not None :
393- configs [index ].output_args .append ("data" )
394-
395- index = get_benchmark_index (configs , "doitgen" )
396- if index is not None :
397- configs [index ].output_args .append ("A" )
398-
399-
400- def modify_args (config : Benchmark , modifier : Callable [[str ], str ]) -> Benchmark :
401- """Applies modifier to function argument names.
402-
403- Current implementation applies modifier to
404- - all presets keys, not preset names;
405- - all input_args;
406- - all array_args;
407- - all output_args.
408- """
409- config .parameters = Presets (
410- {
411- preset : {modifier (k ): v for k , v in parameters .items ()}
412- for preset , parameters in config .parameters .items ()
413- }
414- )
415-
416- config .input_args = [modifier (arg ) for arg in config .input_args ]
417- config .array_args = [modifier (arg ) for arg in config .array_args ]
418- config .output_args = [modifier (arg ) for arg in config .output_args ]
419-
420- return config
0 commit comments