Skip to content

Commit 768d9bc

Browse files
AlexanderKalistratovZzEeKkAa
authored and committed
Add possibility to add custom validation function for workload
1 parent 8a77abb commit 768d9bc

File tree

7 files changed

+107
-30
lines changed

7 files changed

+107
-30
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# SPDX-FileCopyrightText: 2022 - 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
from dpbench.infrastructure.benchmark_validation import (
5+
validate as default_validate,
6+
)
7+
8+
9+
def validate(expected: dict[str, any], actual: dict[str, any], rel_error=1e-05):
    """Validate results of the pca workload.

    For now this simply delegates to the default dpbench validation routine.

    Args:
        expected: expected values.
        actual: actual values.
        rel_error: value forwarded unchanged to the default validation.

    Returns: whatever the default validation function returns.
    """
    # TODO implement actual validation suitable for pca workload
    is_valid = default_validate(expected, actual, rel_error)
    return is_valid

dpbench/config/benchmark.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ class Benchmark:
8080
module_name: str = ""
8181
package_path: str = ""
8282
func_name: str = ""
83+
validate_package_path: str = ""
84+
validate_func_name: str = ""
8385
kind: str = ""
8486
domain: str = ""
8587
parameters: Presets = field(default_factory=Presets)
@@ -100,6 +102,8 @@ def from_dict(obj: Any) -> "Benchmark":
100102
_module_name = str(obj.get("module_name") or "")
101103
_package_path = str(obj.get("package_path") or "")
102104
_func_name = str(obj.get("func_name") or "")
105+
_validate_package_path = str(obj.get("validate_package_path") or "")
106+
_validate_func_name = str(obj.get("validate_func_name") or "validate")
103107
_kind = str(obj.get("kind") or "")
104108
_domain = str(obj.get("domain") or "")
105109
_parameters = Presets(obj.get("parameters"))
@@ -122,6 +126,8 @@ def from_dict(obj: Any) -> "Benchmark":
122126
_module_name,
123127
_package_path,
124128
_func_name,
129+
_validate_package_path,
130+
_validate_func_name,
125131
_kind,
126132
_domain,
127133
_parameters,

dpbench/config/reader.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,7 @@ def read_benchmark_implementations(
340340

341341
setup_init(config, modules)
342342
set_default_reference_implementation_postfix(config, modules)
343+
set_validate_func(config, modules)
343344

344345
for module in modules:
345346
module_name, postfix = discover_module_name_and_postfix(module, config)
@@ -386,6 +387,49 @@ def read_benchmark_implementations(
386387
)
387388

388389

390+
def set_validate_func(
    config: Benchmark,
    modules: set[str] = None,
):
    """Read, discover and populate config with validation module and function.

    Validation package priority, if found/set:
    1. package specified in config
    2. validation package at <benchmark>/<benchmark>_validate.py
    3. default validation package

    Problems (package or function not found) are reported through
    ``logging.fatal``; this function does not raise for them itself, although
    importing a missing package still raises from ``importlib``.

    Args:
        config: Benchmark configuration object where settings should be
            populated.
        modules: Names of the modules available for the benchmark, searched
            for ``<short_name>_validate`` or ``<module_name>_validate``.
            ``None`` is treated as "no modules available".
    """
    if config.validate_package_path != "":
        if importlib.util.find_spec(config.validate_package_path) is None:
            logging.fatal(
                f"validation package path is specified but not found for {config.module_name}"
            )
    else:
        validate_package_path = "dpbench.infrastructure.benchmark_validation"

        # The parameter defaults to None, so guard the membership test.
        available_modules = modules if modules is not None else ()

        for module_name in (
            config.short_name + "_validate",
            config.module_name + "_validate",
        ):
            if module_name in available_modules:
                validate_package_path = config.package_path + "." + module_name
                break

        config.validate_package_path = validate_package_path

    val_mod = importlib.import_module(config.validate_package_path)

    if not hasattr(val_mod, config.validate_func_name):
        # Report the path stored on the config: the local variable only exists
        # when the package was discovered rather than configured explicitly.
        logging.fatal(
            f"validation function '{config.validate_func_name}' not found for "
            + f"{config.module_name} at '{config.validate_package_path}'"
        )
431+
432+
389433
def set_default_reference_implementation_postfix(
390434
config: Benchmark,
391435
modules: set[str] = None,

dpbench/configs/framework_info/dpcpp.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ postfix = "dpcpp"
1010
class = "DpcppFramework"
1111
arch = "cpu"
1212
sycl_device = "cpu"
13-
dpcpp_version = "IntelLLVM 2023.2.0"
13+
dpcpp_version = "IntelLLVM 2024.0.0"
1414

1515
[[framework.postfixes]]
1616
impl_postfix = "sycl"

dpbench/infrastructure/benchmark.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ def init_mod_path(self):
5252
def init_fn_name(self):
5353
return self.info.init.func_name if self.info.init else None
5454

55+
def get_validation_func(self):
    """Return the validation callable configured for this benchmark.

    The module named by ``self.info.validate_package_path`` is imported and
    the attribute named by ``self.info.validate_func_name`` is looked up on it.
    """
    info = self.info
    return getattr(
        importlib.import_module(info.validate_package_path),
        info.validate_func_name,
    )
60+
5561
def get_implementation(self, implementation_postfix: str):
5662
implementation = None
5763

dpbench/infrastructure/benchmark_runner.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import dpbench.config as cfg
1616
from dpbench.infrastructure.benchmark import Benchmark
1717
from dpbench.infrastructure.benchmark_results import BenchmarkResults
18-
from dpbench.infrastructure.benchmark_validation import validate_results
1918
from dpbench.infrastructure.datamodel import store_results
2019
from dpbench.infrastructure.enums import ErrorCodes, ValidationStatusCodes
2120
from dpbench.infrastructure.frameworks import Framework
@@ -263,15 +262,24 @@ def run_benchmark(
263262

264263
if rc.validate and results.error_state == ErrorCodes.SUCCESS:
265264
ref_framework = build_framework(rc.ref_framework)
265+
# TODO: don't run it for every framework, but run it only once.
266266
ref_output = _exec_simple(
267267
bench,
268268
ref_framework,
269269
rc.benchmark.reference_implementation_postfix,
270270
rc.preset,
271271
)
272-
if validate_results(ref_output, output):
273-
results.validation_state = ValidationStatusCodes.SUCCESS
274-
else:
272+
273+
if ref_output:
274+
try:
275+
results.validation_state = ValidationStatusCodes.SUCCESS
276+
validate = bench.get_validation_func()
277+
validated = validate(ref_output, output)
278+
except Exception as e:
279+
logging.error(f"Exception during validation {e.args}")
280+
validated = False
281+
282+
if not ref_output or not validated:
275283
results.validation_state = ValidationStatusCodes.FAILURE
276284
results.error_state = ErrorCodes.FAILED_VALIDATION
277285
results.error_msg = "Validation failed"

dpbench/infrastructure/benchmark_validation.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
import numpy as np
1212

1313

14-
def validate_results(
15-
expected: dict[str, any], actual: dict[str, any], rel_error=1e-05
16-
) -> bool:
17-
"""Checks if expected equals actual with certain precision.
14+
def validate(
    expected: dict[str, any],
    actual: dict[str, any],
    rel_error=1e-05,
):
    """Default validation function.

    Every entry of ``expected`` is compared with the entry of ``actual``
    stored under the same key using ``validate_two_lists_of_array``. Each
    mismatching key is logged once; keys that match are never reported.

    Args:
        expected: expected values.
        actual: actual values.
        rel_error: value forwarded as ``rel_error`` to
            ``validate_two_lists_of_array``.

    Returns: true, if provided data is equal.
    """
    valid = True
    for key in expected:
        # Track the result per key: a sticky accumulator would report every
        # key after the first mismatch as failing without comparing it.
        key_matches = validate_two_lists_of_array(
            expected[key], actual[key], rel_error=rel_error
        )
        if not key_matches:
            valid = False
            logging.error(
                (
                    "Output did not match for {0}. "
                    + "Expected: {1} Actual: {2}"
                ).format(key, expected[key], actual[key])
            )

    return valid
4542

4643

4744
def validate_two_lists_of_array(
@@ -93,6 +90,11 @@ def relative_error(
9390
9491
Returns: relative error.
9592
"""
96-
if np.linalg.norm(ref) == 0.0:
97-
return 0.0
98-
return np.linalg.norm(ref - val) / np.linalg.norm(ref)
93+
ref_norm = np.linalg.norm(ref)
94+
if ref_norm:
95+
val_norm = np.linalg.norm(val)
96+
if val_norm == 0:
97+
return 0.0
98+
ref_norm = val_norm
99+
100+
return np.linalg.norm(ref - val) / ref_norm

0 commit comments

Comments
 (0)