
Commit 26adcf6

Fix multiprocessing
1 parent 7277cb6 commit 26adcf6

File tree

10 files changed: +505 -268 lines changed


.github/workflows/build_and_run.yml

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ jobs:

           dpbench \
           -i python,numpy,dpnp,sycl,numba_n,numba_np,numba_npr,numba_dpex_k,numba_dpex_n,numba_dpex_p \
-          run -r2 --no-print-results
+          run --no-validate -r2 --no-print-results

       - name: Generate report
         shell: bash -l {0}

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ repos:
     hooks:
       - id: pydocstyle
         # TODO: add packages one by one to enforce pydocstyle eventually
-        files: (^dpbench/config/|^scripts/|^dpbench/console/)
+        files: (^dpbench/config/|^scripts/|^dpbench/console/|^dpbench/infrastructure/benchmark_runner.py)
         args: ["--convention=google"]
         # D417 does not work properly:
         # https://github.com/PyCQA/pydocstyle/issues/459

dpbench/console/_namespace.py

Lines changed: 2 additions & 1 deletion
@@ -12,7 +12,7 @@ class Namespace(argparse.Namespace):
     """Namespace class for parsed arguments."""

     benchmarks: set[str]
-    implementations: set[str]
+    implementations: list[str]
     all_implementations: bool
     preset: str
     sycl_device: str
@@ -31,6 +31,7 @@ class Namespace(argparse.Namespace):
     program: str
     color: str
     comparisons: list[str]
+    skip_expected_failures: bool


 class CommaSeparateStringAction(argparse.Action):
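
A side note on the set[str] to list[str] change (a reading of the diff, not something the commit message states): execute_run below iterates args.implementations directly, and a list preserves the order given on the command line, while a set would iterate in an unspecified order. For illustration:

    # Illustration only, not code from the commit.
    impls = "numpy,dpnp,sycl".split(",")
    assert impls == ["numpy", "dpnp", "sycl"]  # list keeps the CLI order
    # A set such as {"numpy", "dpnp", "sycl"} iterates in no guaranteed order.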

dpbench/console/entry.py

Lines changed: 8 additions & 3 deletions
@@ -7,7 +7,11 @@
 import argparse
 from importlib.metadata import version

-from ._namespace import CommaSeparateStringAction, Namespace
+from ._namespace import (
+    CommaSeparateStringAction,
+    CommaSeparateStringListAction,
+    Namespace,
+)
 from .config import add_config_arguments, execute_config
 from .report import add_report_arguments, execute_report
 from .run import add_run_arguments, execute_run
@@ -31,7 +35,7 @@ def parse_args() -> Namespace:
         "-i",
         "--implementations",
         type=str,
-        action=CommaSeparateStringAction,
+        action=CommaSeparateStringListAction,
         nargs="?",
         default={"python", "numpy"},
         help="Comma separated list of implementations. Use "
@@ -100,7 +104,8 @@ def main():
     """Main function to run on dpbench console tool."""
     args = parse_args()

-    if args.program in {"run", "report"}:
+    conn = None
+    if args.program == "report" or args.program == "run" and args.save:
         import dpbench.infrastructure as dpbi
         from dpbench.infrastructure.reporter import update_run_id
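
Two things to note here. The new condition relies on Python operator precedence: and binds tighter than or, so it parses as args.program == "report" or (args.program == "run" and args.save), meaning a database connection is only set up for reports, or for runs that actually save results. Also, CommaSeparateStringListAction is imported above but defined in a part of _namespace.py this diff does not show; a minimal sketch of what such an action plausibly looks like, mirroring the existing CommaSeparateStringAction but producing an ordered list (hypothetical reconstruction, not the commit's actual code):

    import argparse

    class CommaSeparateStringListAction(argparse.Action):
        """Hypothetical sketch of the action imported above."""

        def __call__(self, parser, namespace, values, option_string=None):
            # Split "a,b,c" into ["a", "b", "c"], preserving user order.
            setattr(
                namespace, self.dest, list(values.split(",")) if values else []
            )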

dpbench/console/run.py

Lines changed: 69 additions & 17 deletions
@@ -5,9 +5,14 @@
 """Run subcommand package."""

 import argparse
+import logging

 import sqlalchemy

+import dpbench.config as cfg
+import dpbench.infrastructure as dpbi
+from dpbench.infrastructure.benchmark_runner import BenchmarkRunner, RunConfig
+
 from ._namespace import Namespace


@@ -94,6 +99,23 @@ def add_run_arguments(parser: argparse.ArgumentParser):
         default=None,
         help="Sycl device to overwrite for framework configurations.",
     )
+    parser.add_argument(
+        "--skip-expected-failures",
+        action=argparse.BooleanOptionalAction,
+        default=True,
+        help="Whether to skip implementations that are expected to fail.",
+    )
+
+
+def _find_framework_config(implementation: str) -> cfg.Framework:
+    framework = None
+
+    for f in cfg.GLOBAL.frameworks:
+        for impl in f.postfixes:
+            if impl.postfix == implementation:
+                framework = f
+
+    return framework


 def execute_run(args: Namespace, conn: sqlalchemy.Engine):
@@ -105,33 +127,63 @@ def execute_run(args: Namespace, conn: sqlalchemy.Engine):
         args: object with all input arguments.
         conn: database connection.
     """
-    import dpbench.config as cfg
-    import dpbench.infrastructure as dpbi
-    from dpbench.infrastructure.runner import run_benchmarks
-
     cfg.GLOBAL = cfg.read_configs(
         benchmarks=args.benchmarks,
-        implementations=args.implementations,
+        implementations=set(args.implementations),
         no_dpbench=not args.dpbench,
         with_npbench=args.npbench,
         with_polybench=args.polybench,
     )

+    if args.all_implementations:
+        args.implementations = {
+            impl.postfix for impl in cfg.GLOBAL.implementations
+        }
+
     if args.sycl_device:
         for framework in cfg.GLOBAL.frameworks:
             framework.sycl_device = args.sycl_device

-    if args.run_id is None:
+    if args.save and args.run_id is None:
         args.run_id = dpbi.create_run(conn)

-    run_benchmarks(
-        conn=conn,
-        preset=args.preset,
-        repeat=args.repeat,
-        validate=args.validate,
-        timeout=args.timeout,
-        precision=args.precision,
-        print_results=args.print_results,
-        run_id=args.run_id,
-        implementations=list(args.implementations),
-    )
+    runner = BenchmarkRunner()
+
+    for benchmark in cfg.GLOBAL.benchmarks:
+        print("")
+        print(
+            f"================ Benchmark {benchmark.name} ({benchmark.module_name}) ========================"
+        )
+        print("")
+
+        for implementation in args.implementations:
+            framework = _find_framework_config(implementation)
+
+            if not framework:
+                logging.error(
+                    f"Could not find framework for {implementation} implementation"
+                )
+                continue
+
+            logging.info(
+                f"Running {benchmark.module_name} ({implementation}) on {framework.simple_name}"
+            )
+
+            runner.run_benchmark_and_save(
+                RunConfig(
+                    conn=conn,
+                    benchmark=benchmark,
+                    framework=framework,
+                    implementation=implementation,
+                    preset=args.preset,
+                    repeat=args.repeat,
+                    validate=args.validate,
+                    timeout=args.timeout,
+                    precision=args.precision,
+                    print_results=args.print_results,
+                    run_id=args.run_id,
+                    skip_expected_failures=args.skip_expected_failures,
+                )
+            )
+
+    runner.close_connections()
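
Two details worth noting in this file. First, _find_framework_config scans every framework without breaking out of the loop, so if two frameworks declared the same postfix, the last match would win. Second, argparse.BooleanOptionalAction (Python 3.9+) automatically registers a negated form of the flag, so the new option can be disabled from the command line. A small self-contained demonstration of that mechanism (illustrative, not code from the commit):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--skip-expected-failures",
        action=argparse.BooleanOptionalAction,
        default=True,
    )

    # BooleanOptionalAction registers both forms of the flag:
    assert parser.parse_args([]).skip_expected_failures is True
    assert parser.parse_args(
        ["--no-skip-expected-failures"]
    ).skip_expected_failures is False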

dpbench/infrastructure/benchmark.py

Lines changed: 2 additions & 68 deletions
@@ -366,12 +366,10 @@ def _set_reference_implementation(self) -> BenchmarkImplFn:
         reference_implementations = [
             impl
             for impl in self.impl_fnlist
-            if self.info.reference_implementation_postfix in impl.name
+            if self.info.reference_implementation_postfix
+            and self.info.reference_implementation_postfix in impl.name
         ]

-        print(self.impl_fnlist)
-        print(reference_implementations)
-
         if len(reference_implementations) == 0:
             raise RuntimeError("No reference implementation")

@@ -714,67 +712,3 @@ def get_data_init(
             )
         }
     )
-
-    def run(
-        self,
-        implementation_postfix: str = None,
-        preset: str = "S",
-        repeat: int = 10,
-        validate: bool = True,
-        timeout: float = 200.0,
-        precision: str = None,
-        conn: sqlite3.Connection = None,
-        run_id: int = None,
-    ) -> list[BenchmarkResults]:
-        results: list[BenchmarkResults] = []
-
-        implementation_postfixes = []
-
-        if implementation_postfix:
-            implementation_postfixes.append(implementation_postfix)
-        else:
-            for impl in self.impl_fnlist:
-                impl_postfix = impl.name[
-                    (len(self.bname) - len(impl.name) + 1) :  # noqa: E203
-                ]
-
-                implementation_postfixes.append(impl_postfix)
-
-        # TODO: do we call ref benchmark function twice?
-        for implementation_postfix in implementation_postfixes:
-            # copy_output is true only if validation is needed.
-            runner = BenchmarkRunner(
-                bench=self,
-                impl_postfix=implementation_postfix,
-                preset=preset,
-                repeat=repeat,
-                timeout=timeout,
-                precision=precision,
-                copy_output=validate,
-            )
-            result = runner.results
-            if validate and result.error_state == ErrorCodes.SUCCESS:
-                if self._validate_results(
-                    preset, runner.fmwrk, runner.output, precision
-                ):
-                    result.validation_state = ValidationStatusCodes.SUCCESS
-                else:
-                    result.validation_state = ValidationStatusCodes.FAILURE
-                    result.error_state = ErrorCodes.FAILED_VALIDATION
-                    result.error_msg = "Validation failed"
-            if conn:
-                store_results(
-                    conn,
-                    result.Result(
-                        run_id=run_id,
-                        benchmark_name=self.bname,
-                        framework_version=runner.fmwrk.fname
-                        + " "
-                        + runner.fmwrk.version()
-                        if runner.fmwrk
-                        else "n/a",
-                    ),
-                )
-            results.append((result, runner.fmwrk))
-
-        return results
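
The extra and in the list comprehension is more than style: if reference_implementation_postfix is None, the old code raised a TypeError ('in <string>' requires string as left operand, not NoneType) as soon as it touched impl.name. With the guard, a missing postfix simply yields an empty list, which then surfaces as the clearer RuntimeError("No reference implementation"). A minimal reproduction of the failure mode (illustrative, not code from the commit; the names are hypothetical):

    postfix = None
    names = ["l2_norm_numpy", "l2_norm_dpnp"]  # hypothetical impl names

    try:
        [n for n in names if postfix in n]  # old form: raises on None
    except TypeError:
        pass

    # New form: short-circuits on a falsy postfix instead of raising.
    assert [n for n in names if postfix and postfix in n] == []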
