|
| 1 | +import collections.abc |
| 2 | +from itertools import chain |
| 3 | + |
1 | 4 | from . import SyclDevice, get_devices
|
2 | 5 |
|
3 | 6 |
|
4 |
| -def select_device_with_aspects(aspect_list, deny_list=[]): |
5 |
| - check_list = aspect_list + deny_list |
6 |
| - for asp in check_list: |
| 7 | +def select_device_with_aspects(required_aspects, excluded_aspects=[]): |
| 8 | + """Selects the root :class:`dpctl.SyclDevice` that has the highest |
| 9 | + default selector score among devices that have all aspects in the |
| 10 | + `required_aspects` list, and do not have any aspects in `excluded_aspects` |
| 11 | + list. |
| 12 | +
|
| 13 | + Supported |
| 14 | +
|
| 15 | + :Example: |
| 16 | + .. code-block:: python |
| 17 | +
|
| 18 | + import dpctl |
| 19 | + # select a GPU that supports double precision |
| 20 | + dpctl.select_device_with_aspects(['fp64', 'gpu']) |
| 21 | + # select non-custom device with USM shared allocations |
| 22 | + dpctl.select_device_with_aspects( |
| 23 | + ['usm_shared_allocations'], excluded_aspects=['custom']) |
| 24 | + """ |
| 25 | + if isinstance(required_aspects, str): |
| 26 | + required_aspects = [required_aspects] |
| 27 | + if isinstance(excluded_aspects, str): |
| 28 | + excluded_aspects = [excluded_aspects] |
| 29 | + seq = collections.abc.Sequence |
| 30 | + input_types_ok = isinstance(required_aspects, seq) and isinstance( |
| 31 | + excluded_aspects, seq |
| 32 | + ) |
| 33 | + if not input_types_ok: |
| 34 | + raise TypeError( |
| 35 | + "Aspects are expected to be Python sequences, " |
| 36 | + "e.g. lists, of strings" |
| 37 | + ) |
| 38 | + for asp in chain(required_aspects, excluded_aspects): |
7 | 39 | if type(asp) != str:
|
8 | 40 | raise TypeError("The list objects must be of a string type")
|
9 | 41 | if not hasattr(SyclDevice, "has_aspect_" + asp):
|
10 |
| - raise ValueError(f"The {asp} aspect is not supported in dpctl") |
| 42 | + raise AttributeError(f"The {asp} aspect is not supported in dpctl") |
11 | 43 | devs = get_devices()
|
12 | 44 | max_score = 0
|
13 | 45 | selected_dev = None
|
14 | 46 |
|
15 | 47 | for dev in devs:
|
16 |
| - # aspect_status = True |
17 |
| - # for asp in aspect_list: |
18 |
| - # has_aspect = "dev.has_aspect_" + asp |
19 |
| - # if not eval(has_aspect): |
20 |
| - # aspect_status = False |
21 |
| - # for deny in deny_list: |
22 |
| - # has_aspect = "dev.has_aspect_" + deny |
23 |
| - # if eval(has_aspect): |
24 |
| - # aspect_status = False |
25 | 48 | aspect_status = all(
|
26 |
| - (getattr(dev, "has_aspect_" + asp) is True for asp in aspect_list) |
| 49 | + ( |
| 50 | + getattr(dev, "has_aspect_" + asp) is True |
| 51 | + for asp in required_aspects |
| 52 | + ) |
27 | 53 | )
|
28 | 54 | aspect_status = aspect_status and not (
|
29 | 55 | any(
|
30 |
| - (getattr(dev, "has_aspect_" + asp) is True for asp in deny_list) |
| 56 | + ( |
| 57 | + getattr(dev, "has_aspect_" + asp) is True |
| 58 | + for asp in excluded_aspects |
| 59 | + ) |
31 | 60 | )
|
32 | 61 | )
|
33 | 62 | if aspect_status and dev.default_selector_score > max_score:
|
34 | 63 | max_score = dev.default_selector_score
|
35 | 64 | selected_dev = dev
|
36 | 65 |
|
| 66 | + if selected_dev is None: |
| 67 | + raise ValueError( |
| 68 | + f"Requested device is unavailable: " |
| 69 | + f"required_aspects={required_aspects}, " |
| 70 | + f"excluded_aspects={excluded_aspects}" |
| 71 | + ) |
| 72 | + |
37 | 73 | return selected_dev
|
0 commit comments