Skip to content

Commit 4c4b34f

Browse files
Add select_device_with_aspects func
1 parent 8e716ae commit 4c4b34f

File tree

3 files changed

+81
-0
lines changed

3 files changed

+81
-0
lines changed

dpctl/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
set_global_queue,
6565
)
6666

67+
from ._device_selection import select_device_with_aspects
6768
from ._version import get_versions
6869
from .enum_types import backend_type, device_type, event_status_type
6970

@@ -80,6 +81,7 @@
8081
"select_default_device",
8182
"select_gpu_device",
8283
"select_host_device",
84+
"select_device_with_aspects",
8385
"get_num_devices",
8486
"has_cpu_devices",
8587
"has_gpu_devices",

dpctl/_device_selection.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from . import SyclDevice, get_devices
2+
3+
4+
def select_device_with_aspects(aspect_list, deny_list=[]):
5+
check_list = aspect_list + deny_list
6+
for asp in check_list:
7+
if type(asp) != str:
8+
raise TypeError("The list objects must be of a string type")
9+
if not hasattr(SyclDevice, "has_aspect_" + asp):
10+
raise ValueError(f"The {asp} aspect is not supported in dpctl")
11+
devs = get_devices()
12+
max_score = 0
13+
selected_dev = None
14+
15+
for dev in devs:
16+
aspect_status = True
17+
for asp in aspect_list:
18+
has_aspect = "dev.has_aspect_" + asp
19+
if not eval(has_aspect):
20+
aspect_status = False
21+
for deny in deny_list:
22+
has_aspect = "dev.has_aspect_" + deny
23+
if eval(has_aspect):
24+
aspect_status = False
25+
if aspect_status and dev.default_selector_score > max_score:
26+
max_score = dev.default_selector_score
27+
selected_dev = dev
28+
29+
return selected_dev

dpctl/tests/test_sycl_device.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,3 +645,53 @@ def test_hashing_of_device():
645645
"""
646646
device_dict = {dpctl.SyclDevice(): "default_device"}
647647
assert device_dict
648+
649+
650+
list_of_valid_aspects = [
651+
"cpu",
652+
"gpu",
653+
"accelerator",
654+
"custom",
655+
"fp16",
656+
"fp64",
657+
"image",
658+
"online_compiler",
659+
"online_linker",
660+
"queue_profiling",
661+
"usm_device_allocations",
662+
"usm_host_allocations",
663+
"usm_shared_allocations",
664+
"usm_system_allocator",
665+
]
666+
667+
list_of_invalid_aspects = [
668+
"emulated",
669+
"host_debuggable",
670+
"atomic64",
671+
"usm_atomic_host_allocations",
672+
"usm_atomic_shared_allocations",
673+
]
674+
675+
676+
@pytest.fixture(params=list_of_valid_aspects)
677+
def valid_aspects(request):
678+
return request.param
679+
680+
681+
@pytest.fixture(params=list_of_invalid_aspects)
682+
def invalid_aspects(request):
683+
return request.param
684+
685+
686+
def test_valid_aspects(valid_aspects):
687+
dpctl.select_device_with_aspects([valid_aspects])
688+
689+
690+
def test_invalid_aspects(invalid_aspects):
691+
try:
692+
dpctl.select_device_with_aspects([invalid_aspects])
693+
raise AttributeError(
694+
f"The {invalid_aspects} aspect is supported in dpctl"
695+
)
696+
except ValueError:
697+
pytest.skip(f"The {invalid_aspects} aspect is not supported in dpctl")

0 commit comments

Comments
 (0)