DLPack conversion does not work #40

@ZzEeKkAa

Description


System: WSL2 Ubuntu 22.04, on top of Windows 11
CPU: 1270P
GPU: integrated ([ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Graphics [0x46a6] 1.3 [1.3.26032])
TensorFlow: 2.12.0
JAX: 0.4.4

import jax.numpy as jnp
jnp.arange(10, dtype=jnp.float32).__dlpack__()

This results in the following error:

Python 3.10.6 (main, May 29 2023, 11:10:38) [GCC 11.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax.numpy as jnp
>>> jnp.arange(10, dtype=jnp.float32).__dlpack__()
2023-07-06 21:56:46.726166: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:169] XLA service 0x558bc41f0c80 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2023-07-06 21:56:46.726195: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:177]   StreamExecutor device (0): Interpreter, <undefined>
2023-07-06 21:56:46.732904: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc:215] TfrtCpuClient created.
2023-07-06 21:56:46.733628: I external/org_tensorflow/tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.cc:266] Libtpu path is: libtpu.so
2023-07-06 21:56:46.733881: I external/org_tensorflow/tensorflow/compiler/xla/stream_executor/tpu/tpu_platform_interface.cc:73] No TPU platform found.
2023-07-06 21:56:46.993558: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_api.cc:85] GetPjrtApi was found for xpu at /home/yevhenii/Projects/users.yevhenii/examples/jax/libitex_xla_extension.so
2023-07-06 21:56:46.993610: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_api.cc:58] PJRT_Api is set for device type xpu
2023-07-06 21:56:46.994083: I itex/core/devices/gpu/itex_gpu_runtime.cc:129] Selected platform: Intel(R) Level-Zero
2023-07-06 21:56:46.994362: I itex/core/devices/gpu/itex_gpu_runtime.cc:154] number of sub-devices is zero, expose root device.
2023-07-06 21:56:46.998761: I itex/core/compiler/xla/service/service.cc:176] XLA service 0x558bc64650a0 initialized for platform sycl (this does not guarantee that XLA will be used). Devices:
2023-07-06 21:56:46.998790: I itex/core/compiler/xla/service/service.cc:184]   StreamExecutor device (0): <undefined>, <undefined>
2023-07-06 21:56:47.000500: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc:83] PjRtCApiClient created.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/yevhenii/.local/share/virtualenvs/numba-dpex-x1V09ZPr/lib/python3.10/site-packages/jax/_src/array.py", line 343, in __dlpack__
    return to_dlpack(self)
  File "/home/yevhenii/.local/share/virtualenvs/numba-dpex-x1V09ZPr/lib/python3.10/site-packages/jax/_src/dlpack.py", line 51, in to_dlpack
    return xla_client._xla.buffer_to_dlpack_managed_tensor(
jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: PJRT C API does not support AcquireExternalReference

JAX itself works on the Level Zero GPU, so the environment is not broken. I suspect the cause is the missing implementation of `AcquireExternalReference` in the PJRT C API plugin. This blocks user workflows that need both JAX operations and, for example, numba_dpex operations on the same buffer without memory copies.
