Skip to content

Commit 4a48b84

Browse files
rsudermankeshavvinayak01
authored andcommitted
Add support for ml_dtypes to python runtime bindings (iree-org#21549)
Python runtime bindings should include support for atypical dtypes provided by `ml_dtypes`. We also drop the requirement for `PYBUF_FORMAT` as we never look at the format structure so this causes failures when handling the atypical dtypes. --------- Signed-off-by: Rob Suderman <[email protected]> Signed-off-by: keshavvinayak01 <[email protected]>
1 parent 9b34c7b commit 4a48b84

File tree

6 files changed

+15
-4
lines changed

6 files changed

+15
-4
lines changed

runtime/bindings/python/hal.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ py::object HalAllocator::AllocateBufferCopy(
180180
// only request that via PyBUF_ND. Long term, we should consult an
181181
// "oracle" in the runtime to determine the precise required format
182182
// and set flags accordingly (and fallback/copy on failure).
183-
int flags = PyBUF_FORMAT | PyBUF_ND;
183+
int flags = PyBUF_C_CONTIGUOUS | PyBUF_ND;
184184

185185
// Acquire the backing buffer and setup RAII release.
186186
if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) {
@@ -243,7 +243,7 @@ HalBuffer HalAllocator::AllocateHostStagingBufferCopy(HalDevice& device,
243243
// only request that via PyBUF_ND. Long term, we should consult an
244244
// "oracle" in the runtime to determine the precise required format
245245
// and set flags accordingly (and fallback/copy on failure).
246-
int flags = PyBUF_FORMAT | PyBUF_ND;
246+
int flags = PyBUF_C_CONTIGUOUS | PyBUF_ND;
247247

248248
// Acquire the backing buffer and setup RAII release.
249249
if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) {
@@ -2025,7 +2025,7 @@ void SetupHalBindings(nanobind::module_ m) {
20252025
py::handle pattern, iree_device_size_t target_offset,
20262026
std::optional<iree_device_size_t> length, bool end) {
20272027
Py_buffer pattern_view;
2028-
int flags = PyBUF_FORMAT | PyBUF_ND;
2028+
int flags = PyBUF_C_CONTIGUOUS | PyBUF_ND;
20292029
if (PyObject_GetBuffer(pattern.ptr(), &pattern_view, flags) != 0) {
20302030
// The GetBuffer call is required to set an appropriate error.
20312031
throw py::python_error();

runtime/bindings/python/iree/runtime/_binding.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,10 +299,11 @@ class HalElementType:
299299
BOOL_8: ClassVar[HalElementType] = ...
300300
COMPLEX_128: ClassVar[HalElementType] = ...
301301
COMPLEX_64: ClassVar[HalElementType] = ...
302-
FLOAT_8_E4M3: ClassVar[HalElementType] = ...
302+
FLOAT_8_E4M3_FN: ClassVar[HalElementType] = ...
303303
FLOAT_8_E4M3_FNUZ: ClassVar[HalElementType] = ...
304304
FLOAT_8_E5M2: ClassVar[HalElementType] = ...
305305
FLOAT_8_E5M2_FNUZ: ClassVar[HalElementType] = ...
306+
FLOAT_8_E8M0_FNU: ClassVar[HalElementType] = ...
306307
FLOAT_16: ClassVar[HalElementType] = ...
307308
FLOAT_32: ClassVar[HalElementType] = ...
308309
FLOAT_64: ClassVar[HalElementType] = ...

runtime/bindings/python/iree/runtime/array_interop.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from typing import Optional, Tuple
99
import logging
10+
import ml_dtypes
1011
import numpy as np
1112
import numpy.lib.mixins
1213

@@ -305,6 +306,12 @@ def asdevicearray(
305306
(np.bool_, HalElementType.BOOL_8),
306307
(np.complex64, HalElementType.COMPLEX_64),
307308
(np.complex128, HalElementType.COMPLEX_128),
309+
(ml_dtypes.bfloat16, HalElementType.BFLOAT_16),
310+
(ml_dtypes.float8_e4m3fn, HalElementType.FLOAT_8_E4M3_FN),
311+
(ml_dtypes.float8_e4m3fnuz, HalElementType.FLOAT_8_E4M3_FNUZ),
312+
(ml_dtypes.float8_e5m2, HalElementType.FLOAT_8_E5M2),
313+
(ml_dtypes.float8_e5m2fnuz, HalElementType.FLOAT_8_E5M2_FNUZ),
314+
(ml_dtypes.float8_e8m0fnu, HalElementType.FLOAT_8_E8M0_FNU),
308315
)
309316

310317

runtime/bindings/python/iree/runtime/build_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ numpy>=2.0.0b1
99
requests>=2.28.0
1010
wheel>=0.36.2
1111
sympy==1.12.1
12+
ml_dtypes>=0.5.1

runtime/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ requires = [
66
"ninja",
77
"numpy>=2.0.0b1",
88
"packaging",
9+
"ml_dtypes>=0.5.1",
910
]
1011
build-backend = "setuptools.build_meta"

runtime/setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -672,5 +672,6 @@ def populate_built_package(abs_dir):
672672
},
673673
install_requires=[
674674
"numpy",
675+
"ml_dtypes",
675676
],
676677
)

0 commit comments

Comments
 (0)