System: WSL2 Ubuntu 22.04, on top of Windows 11
CPU: 1270P
GPU: integrated ([ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Graphics [0x46a6] 1.3 [1.3.26032])
TensorFlow: 2.12.0
JAX: 0.4.4
import jax.numpy as jnp
jnp.arange(10, dtype=jnp.float32).__dlpack__()
Results in an error:
Python 3.10.6 (main, May 29 2023, 11:10:38) [GCC 11.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax.numpy as jnp
>>> jnp.arange(10, dtype=jnp.float32).__dlpack__()
2023-07-06 21:56:46.726166: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:169] XLA service 0x558bc41f0c80 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2023-07-06 21:56:46.726195: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:177] StreamExecutor device (0): Interpreter, <undefined>
2023-07-06 21:56:46.732904: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc:215] TfrtCpuClient created.
2023-07-06 21:56:46.733628: I external/org_tensorflow/tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.cc:266] Libtpu path is: libtpu.so
2023-07-06 21:56:46.733881: I external/org_tensorflow/tensorflow/compiler/xla/stream_executor/tpu/tpu_platform_interface.cc:73] No TPU platform found.
2023-07-06 21:56:46.993558: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_api.cc:85] GetPjrtApi was found for xpu at /home/yevhenii/Projects/users.yevhenii/examples/jax/libitex_xla_extension.so
2023-07-06 21:56:46.993610: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_api.cc:58] PJRT_Api is set for device type xpu
2023-07-06 21:56:46.994083: I itex/core/devices/gpu/itex_gpu_runtime.cc:129] Selected platform: Intel(R) Level-Zero
2023-07-06 21:56:46.994362: I itex/core/devices/gpu/itex_gpu_runtime.cc:154] number of sub-devices is zero, expose root device.
2023-07-06 21:56:46.998761: I itex/core/compiler/xla/service/service.cc:176] XLA service 0x558bc64650a0 initialized for platform sycl (this does not guarantee that XLA will be used). Devices:
2023-07-06 21:56:46.998790: I itex/core/compiler/xla/service/service.cc:184] StreamExecutor device (0): <undefined>, <undefined>
2023-07-06 21:56:47.000500: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc:83] PjRtCApiClient created.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/yevhenii/.local/share/virtualenvs/numba-dpex-x1V09ZPr/lib/python3.10/site-packages/jax/_src/array.py", line 343, in __dlpack__
    return to_dlpack(self)
  File "/home/yevhenii/.local/share/virtualenvs/numba-dpex-x1V09ZPr/lib/python3.10/site-packages/jax/_src/dlpack.py", line 51, in to_dlpack
    return xla_client._xla.buffer_to_dlpack_managed_tensor(
jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: PJRT C API does not support AcquireExternalReference
JAX itself works on the Level Zero GPU, so the environment is not broken. I guess the cause is that the PJRT C API does not implement AcquireExternalReference, which the DLPack export path needs. This blocks user workflows that combine JAX operations with, for example, numba_dpex operations without copying memory.
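Until this is supported, a possible stopgap on the user side is to try the zero-copy export first and fall back to an explicit copy when the backend reports UNIMPLEMENTED. The following is only a rough sketch of that idea, not a fix: `export_for_interop`, `host_copy`, and `_NoDlpackArray` are hypothetical names made up for illustration, and the stub class just mimics the error above so the snippet runs without JAX or an XPU device.

```python
from typing import Any, Callable, Tuple


def export_for_interop(array: Any, host_copy: Callable[[Any], Any]) -> Tuple[str, Any]:
    """Try zero-copy DLPack export; fall back to an explicit copy if unsupported.

    Returns ("dlpack", capsule) on success, or ("copy", host_copy(array))
    when the export fails with an UNIMPLEMENTED error.
    """
    try:
        return "dlpack", array.__dlpack__()
    except Exception as err:
        # Caught broadly on purpose: the concrete exception class depends on
        # the backend. Anything that is not an UNIMPLEMENTED error is re-raised.
        if "UNIMPLEMENTED" not in str(err):
            raise
        return "copy", host_copy(array)


class _NoDlpackArray:
    """Hypothetical stand-in that mimics the failure reported above."""

    def __init__(self, values):
        self.values = list(values)

    def __dlpack__(self):
        raise RuntimeError(
            "UNIMPLEMENTED: PJRT C API does not support AcquireExternalReference"
        )


# The stub fails to export, so the helper falls back to the copy path.
kind, payload = export_for_interop(_NoDlpackArray(range(3)), lambda a: list(a.values))
print(kind, payload)
```

With a real JAX array, `host_copy` could be something like `numpy.asarray`, which copies the data to host memory. That keeps the workflow running, but it is exactly the memory copy this issue asks to avoid, so it does not replace proper DLPack support in the plugin.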