Skip to content

Commit 81b362d

Browse files
dpctl.tensor.Device uses global cache instead of cache stored in the class
1 parent a4125c6 commit 81b362d

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

dpctl/tensor/_device.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
import dpctl
17+
from dpctl._sycl_queue_manager import get_device_cached_queue
1718

1819
__doc__ = "Implementation of array API mandated Device class"
1920

@@ -60,9 +61,7 @@ def create_device(cls, dev):
6061
elif isinstance(dev, dpctl.SyclDevice):
6162
par = dev.parent_device
6263
if par is None:
63-
if dev not in cls.__device_queue_map__:
64-
cls.__device_queue_map__[dev] = dpctl.SyclQueue(dev)
65-
obj.sycl_queue_ = cls.__device_queue_map__[dev]
64+
obj.sycl_queue_ = get_device_cached_queue(dev)
6665
else:
6766
raise ValueError(
6867
f"Using non-root device {dev} to specify offloading "
@@ -74,9 +73,7 @@ def create_device(cls, dev):
7473
_dev = dpctl.SyclDevice()
7574
else:
7675
_dev = dpctl.SyclDevice(dev)
77-
if _dev not in cls.__device_queue_map__:
78-
cls.__device_queue_map__[_dev] = dpctl.SyclQueue(_dev)
79-
obj.sycl_queue_ = cls.__device_queue_map__[_dev]
76+
obj.sycl_queue_ = get_device_cached_queue(_dev)
8077
return obj
8178

8279
@property

0 commit comments

Comments
 (0)