Skip to content

Commit 583ee26

Browse files
committed
Add required packages to frameworks
1 parent ee1357c commit 583ee26

File tree

9 files changed

+57
-2
lines changed

9 files changed

+57
-2
lines changed

dpbench/config/reader.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
import tomli
1515

16+
from dpbench.infrastructure.frameworks.fabric import get_framework_class
17+
1618
from .benchmark import Benchmark, BenchmarkImplementation, Presets
1719
from .config import Config
1820
from .framework import Framework
@@ -117,7 +119,7 @@ def read_configs( # noqa: C901: TODO: move modules into config
117119
for framework in config.frameworks:
118120
config.implementations += framework.postfixes
119121

120-
if implementations is None:
122+
if implementations is None or len(implementations) == 0:
121123
implementations = {impl.postfix for impl in config.implementations}
122124

123125
if load_implementations:
@@ -228,6 +230,15 @@ def read_frameworks(
228230
if len(framework.postfixes) == 0:
229231
continue
230232

233+
cls = get_framework_class(framework)
234+
unavailable_pkgs = cls.get_missing_required_packages()
235+
if len(unavailable_pkgs) > 0:
236+
logging.warning(
237+
f"Framework {framework.simple_name} unavailable "
238+
+ f"due to missing packages {unavailable_pkgs}"
239+
)
240+
continue
241+
231242
config.frameworks.append(framework)
232243

233244

dpbench/infrastructure/frameworks/cupy_framework.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ def __init__(self, fname: str = None, config: cfg.Framework = None):
1919

2020
super().__init__(fname, config)
2121

22+
@staticmethod
23+
def required_packages() -> list[str]:
24+
return ["cupy"]
25+
2226
def copy_to_func(self) -> Callable:
2327
"""Returns the copy-method that should be used
2428
for copying the benchmark arguments."""

dpbench/infrastructure/frameworks/dpnp_framework.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ def __init__(self, fname: str = None, config: cfg.Framework = None):
3535
)
3636
raise sdce
3737

38+
@staticmethod
39+
def required_packages() -> list[str]:
40+
return ["dpnp"]
41+
3842
def device_filter_string(self) -> str:
3943
"""Returns the sycl device's filter string if the framework has an
4044
associated sycl device."""

dpbench/infrastructure/frameworks/fabric.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def build_framework_map() -> dict[str, Framework]:
3939
return result
4040

4141

42-
def build_framework(framework_config: cfg.Framework) -> Framework:
42+
def get_framework_class(framework_config: cfg.Framework) -> Framework:
4343
available_classes = [
4444
Framework,
4545
DpcppFramework,
@@ -61,4 +61,9 @@ def build_framework(framework_config: cfg.Framework) -> Framework:
6161
)
6262
constructor = Framework
6363

64+
return constructor
65+
66+
67+
def build_framework(framework_config: cfg.Framework) -> Framework:
68+
constructor = get_framework_class(framework_config)
6469
return constructor(config=framework_config)

dpbench/infrastructure/frameworks/framework.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# SPDX-License-Identifier: BSD-3-Clause
66

77
import logging
8+
from importlib.util import find_spec
89
from typing import Any, Callable, Dict, final
910

1011
import pkg_resources
@@ -45,6 +46,20 @@ def __init__(
4546

4647
self.device_info = cpuinfo.get_cpu_info().get("brand_raw")
4748

49+
@staticmethod
50+
def required_packages() -> list[str]:
51+
return []
52+
53+
@classmethod
54+
def get_missing_required_packages(cls) -> None:
55+
unavailable_packages = []
56+
for pkg in cls.required_packages():
57+
spec = find_spec(pkg)
58+
if spec is None:
59+
unavailable_packages.append(pkg)
60+
61+
return unavailable_packages
62+
4863
def device_filter_string(self) -> str:
4964
"""Returns the sycl device's filter string if the framework has an
5065
associated sycl device."""

dpbench/infrastructure/frameworks/numba_cuda_framework.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ def __init__(self, fname: str = None, config: cfg.Framework = None):
1919

2020
super().__init__(fname, config)
2121

22+
@staticmethod
23+
def required_packages() -> list[str]:
24+
return ["cupy"]
25+
2226
def copy_to_func(self) -> Callable:
2327
"""Returns the copy-method that should be used
2428
for copying the benchmark arguments."""

dpbench/infrastructure/frameworks/numba_dpex_framework.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ def __init__(self, fname: str = None, config: cfg.Framework = None):
3333
)
3434
raise sdce
3535

36+
@staticmethod
37+
def required_packages() -> list[str]:
38+
return ["numba_dpex"]
39+
3640
def device_filter_string(self) -> str:
3741
"""Returns the sycl device's filter string if the framework has an
3842
associated sycl device."""

dpbench/infrastructure/frameworks/numba_framework.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,7 @@ def __init__(self, fname: str = None, config: cfg.Framework = None):
2727
"""
2828

2929
super().__init__(fname, config)
30+
31+
@staticmethod
32+
def required_packages() -> list[str]:
33+
return ["numba"]

dpbench/infrastructure/frameworks/numba_mlir_framework.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ def __init__(self, fname: str = None, config: cfg.Framework = None):
2828

2929
self.device_info = dpctl.SyclDevice(self.sycl_device).name
3030

31+
@staticmethod
32+
def required_packages() -> list[str]:
33+
return ["numba_mlir"]
34+
3135
def copy_to_func(self) -> Callable:
3236
"""Returns the copy-method that should be used
3337
for copying the benchmark arguments to device."""

0 commit comments

Comments
 (0)