Skip to content

Commit c367d0f

Browse files
committed
Move ref impl to config
1 parent 5e151fc commit c367d0f

File tree

3 files changed

+47
-29
lines changed

3 files changed

+47
-29
lines changed

dpbench/config/benchmark.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ class Benchmark:
8888
array_args: List[str] = field(default_factory=list)
8989
output_args: List[str] = field(default_factory=list)
9090
implementations: List[BenchmarkImplementation] = field(default_factory=list)
91+
reference_implementation_postfix: str = None
9192
expected_failure_implementations: List[str] = field(default_factory=list)
9293

9394
@staticmethod
@@ -108,6 +109,9 @@ def from_dict(obj: Any) -> "Benchmark":
108109
_array_args = obj.get("array_args") or []
109110
_output_args = obj.get("output_args") or []
110111
_implementations = obj.get("implementations") or []
112+
_reference_implementation_postfix = obj.get(
113+
"reference_implementation_postfix"
114+
)
111115
_expected_failure_implementations = (
112116
obj.get("expected_failure_implementations") or []
113117
)
@@ -126,5 +130,6 @@ def from_dict(obj: Any) -> "Benchmark":
126130
_array_args,
127131
_output_args,
128132
_implementations,
133+
_reference_implementation_postfix,
129134
_expected_failure_implementations,
130135
)

dpbench/config/reader.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from .implementation_postfix import Implementation
2020
from .module import Module
2121

22+
_REFERENCE_IMPLEMENTATIONS = {"numpy", "python"}
23+
2224

2325
def read_configs( # noqa: C901: TODO: move modules into config
2426
benchmarks: set[str] = None,
@@ -207,6 +209,7 @@ def read_frameworks(
207209
postfix
208210
for postfix in framework.postfixes
209211
if postfix.postfix in implementations
212+
or postfix.postfix in _REFERENCE_IMPLEMENTATIONS
210213
]
211214

212215
if len(framework.postfixes) == 0:
@@ -215,24 +218,6 @@ def read_frameworks(
215218
config.frameworks.append(framework)
216219

217220

218-
def read_implementation_postfixes(
219-
config: Config, impl_postfix_file: str
220-
) -> None:
221-
"""Read and populate implementation postfix configuration file.
222-
223-
Args:
224-
config: Configuration object where settings should be populated.
225-
impl_postfix_file: Path to the configuration file.
226-
"""
227-
with open(impl_postfix_file) as file:
228-
file_contents = file.read()
229-
230-
implementation_postfixes = tomli.loads(file_contents)
231-
for impl in implementation_postfixes["implementations"]:
232-
implementation = Implementation.from_dict(impl)
233-
config.implementations.append(implementation)
234-
235-
236221
def read_precision_dtypes(config: Config, precision_dtypes_file: str) -> None:
237222
"""Read and populate dtype_obj data types file.
238223
@@ -341,13 +326,15 @@ def read_benchmark_implementations(
341326
]
342327

343328
setup_init(config, modules)
329+
set_default_reference_implementation_postfix(config, modules)
344330

345331
for module in modules:
346332
module_name, postfix = discover_module_name_and_postfix(module, config)
347333

348334
if (
349335
implementations
350336
and (postfix not in implementations)
337+
and (postfix != config.reference_implementation_postfix)
351338
or (config.init and config.init.module_name.endswith(module_name))
352339
):
353340
continue
@@ -382,6 +369,34 @@ def read_benchmark_implementations(
382369
)
383370

384371

372+
def set_default_reference_implementation_postfix(
373+
config: Benchmark,
374+
modules: set[str] = None,
375+
):
376+
"""Sets reference implementation postfix if not set.
377+
378+
It will set it to 'numpy' or 'python' with priority to 'numpy' depending on
379+
the available modules.
380+
381+
Args:
382+
config: Benchmark configuration object where settings should be
383+
populated.
384+
modules: List of modules in benchmark implementation dir.
385+
"""
386+
if config.reference_implementation_postfix:
387+
return
388+
389+
postfixes = {
390+
discover_module_name_and_postfix(module, config)[1]
391+
for module in modules
392+
}
393+
394+
for postfix in _REFERENCE_IMPLEMENTATIONS:
395+
if postfix in postfixes:
396+
config.reference_implementation_postfix = postfix
397+
break
398+
399+
385400
def get_benchmark_index(configs: list[Benchmark], module_name: str) -> int:
386401
"""Finds configuration index by module name."""
387402
return next(

dpbench/infrastructure/benchmark.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -363,21 +363,19 @@ def _set_reference_implementation(self) -> BenchmarkImplFn:
363363
364364
"""
365365

366-
ref_impl = None
367-
368-
python_impl = [
369-
impl for impl in self.impl_fnlist if "python" in impl.name
366+
reference_implementations = [
367+
impl
368+
for impl in self.impl_fnlist
369+
if self.info.reference_implementation_postfix in impl.name
370370
]
371-
numpy_impl = [impl for impl in self.impl_fnlist if "numpy" in impl.name]
372371

373-
if numpy_impl:
374-
ref_impl = numpy_impl[0]
375-
elif python_impl:
376-
ref_impl = python_impl[0]
377-
else:
372+
print(self.impl_fnlist)
373+
print(reference_implementations)
374+
375+
if len(reference_implementations) == 0:
378376
raise RuntimeError("No reference implementation")
379377

380-
return ref_impl
378+
return reference_implementations[0]
381379

382380
def _set_impl_to_framework_map(self) -> dict[str, Framework]:
383381
"""Create a dictionary mapping each implementation function name to a

0 commit comments

Comments
 (0)