Skip to content

Commit a1f48ac

Browse files
committed
FEAT: Adding functions to query device and active backends
1 parent a9be255 commit a1f48ac

File tree

1 file changed

+39
-18
lines changed

1 file changed

+39
-18
lines changed

arrayfire/library.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class ERR(_Enum):
4545
TYPE = _Enum_Type(204)
4646
DIFF_TYPE = _Enum_Type(205)
4747
BATCH = _Enum_Type(207)
48+
DEVICE = _Enum_Type(208)
4849

4950
# 300-399 Errors for missing software features
5051
NOT_SUPPORTED = _Enum_Type(301)
@@ -524,12 +525,9 @@ def get_backend_id(A):
524525
name : str.
525526
Backend name
526527
"""
527-
if (backend.is_unified()):
528-
backend_id = ct.c_int(BACKEND.DEFAULT.value)
529-
safe_call(backend.get().af_get_backend_id(ct.pointer(backend_id), A.arr))
530-
return backend.get_name(backend_id.value)
531-
else:
532-
return backend.name()
528+
backend_id = ct.c_int(BACKEND.CPU.value)
529+
safe_call(backend.get().af_get_backend_id(ct.pointer(backend_id), A.arr))
530+
return backend.get_name(backend_id.value)
533531

534532
def get_backend_count():
535533
"""
@@ -541,12 +539,9 @@ def get_backend_count():
541539
count : int
542540
Number of available backends
543541
"""
544-
if (backend.is_unified()):
545-
count = ct.c_int(0)
546-
safe_call(backend.get().af_get_backend_count(ct.pointer(count)))
547-
return count.value
548-
else:
549-
return 1
542+
count = ct.c_int(0)
543+
safe_call(backend.get().af_get_backend_count(ct.pointer(count)))
544+
return count.value
550545

551546
def get_available_backends():
552547
"""
@@ -558,11 +553,37 @@ def get_available_backends():
558553
names : tuple of strings
559554
Names of available backends
560555
"""
561-
if (backend.is_unified()):
562-
available = ct.c_int(0)
563-
safe_call(backend.get().af_get_available_backends(ct.pointer(available)))
564-
return backend.parse(int(available.value))
565-
else:
566-
return (backend.name(),)
556+
available = ct.c_int(0)
557+
safe_call(backend.get().af_get_available_backends(ct.pointer(available)))
558+
return backend.parse(int(available.value))
559+
560+
def get_active_backend():
561+
"""
562+
Get the current active backend
563+
564+
name : str.
565+
Backend name
566+
"""
567+
backend_id = ct.c_int(BACKEND.CPU.value)
568+
safe_call(backend.get().af_get_active_backend(ct.pointer(backend_id)))
569+
return backend.get_name(backend_id.value)
570+
571+
def get_device_id(A):
572+
"""
573+
Get the device id of the array
574+
575+
Parameters
576+
----------
577+
A : af.Array
578+
579+
Returns
580+
----------
581+
582+
dev : Integer
583+
id of the device array was created on
584+
"""
585+
device_id = ct.c_int(0)
586+
safe_call(backend.get().af_get_device_id(ct.pointer(device_id), A.arr))
587+
return device_id
567588

568589
from .util import safe_call

0 commit comments

Comments
 (0)