Skip to content

Commit 14bf63f

Browse files
author
Diptorup Deb
committed
Add helper functions to get current backend.
1 parent c96c50d commit 14bf63f

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

dpctl/sycl_core.pyx

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,6 @@ cdef class _SyclRTManager:
481481
'''
482482
cdef dict _backend_ty_dict
483483
cdef dict _device_ty_dict
484-
cdef SyclQueue current_queue
485484

486485
def __cinit__ (self):
487486

@@ -518,6 +517,16 @@ cdef class _SyclRTManager:
518517
'''
519518
DPPLPlatform_DumpInfo()
520519

520+
def print_available_backends (self):
521+
""" Prints the available backends.
522+
"""
523+
print(self._backend_ty_dict.keys())
524+
525+
def get_current_backend (self):
526+
""" Returns the backend for the current queue as `backend_type` enum
527+
"""
528+
return self.get_current_queue().get_sycl_backend()
529+
521530
def get_current_device_type (self):
522531
''' Returns current device type as `device_type` enum
523532
'''
@@ -613,9 +622,9 @@ def create_program_from_source (SyclQueue q, unicode source, unicode copts=""):
613622
copts (unicode) : Optional compilation flags that will be used
614623
when compiling the program.
615624
616-
Returns:
617-
program (SyclProgram): A SyclProgram object wrapping the
618-
syc::program returned by the C API.
625+
Returns:
626+
program (SyclProgram): A SyclProgram object wrapping the
627+
syc::program returned by the C API.
619628
'''
620629

621630
cdef DPPLSyclProgramRef Pref
@@ -674,9 +683,14 @@ def device_context (str queue_str="opencl:gpu:0"):
674683
# If set_context is unable to create a new context an exception is raised.
675684
try:
676685
attrs = queue_str.split(':')
677-
if len(attrs) != 3:
686+
nattrs = len(attrs)
687+
if (nattrs < 2 or nattrs > 3):
678688
raise ValueError("Invalid device context string. Should be "
679-
" backend:device:device_number")
689+
"backend:device:device_number or "
690+
"backend:device. In the later case the "
691+
"device_number defaults to 0")
692+
if nattrs == 2:
693+
attrs.append("0")
680694
ctxt = None
681695
ctxt = _mgr._set_as_current_queue(attrs[0], attrs[1], int(attrs[2]))
682696
yield ctxt

0 commit comments

Comments
 (0)