Skip to content

Commit 08e261e

Browse files
Tensorstore Teamcopybara-github
authored and committed
Add Float4e2m1fn support in tensorstore.
PiperOrigin-RevId: 836265256 Change-Id: Id4c0848b7207420cd8fbe2f6856047cfc256b5d3
1 parent 367ef7d commit 08e261e

21 files changed

+673
-10
lines changed

docs/tensorstore_schema.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ definitions:
115115
title: |
116116
`8-bit floating point - Exponent: 5, Mantissa: 2, bias: 16, with NaN, without infinities.
117117
<https://github.com/jax-ml/ml_dtypes#float8_e5m2fnuz>`__ .
118+
- const: "float4_e2m1fn"
119+
title: |
120+
`4-bit floating point - Exponent: 2, Mantissa: 1, bias: 1, without NaN and infinities.
121+
<https://github.com/jax-ml/ml_dtypes#float4_e2m1fn>`__ .
118122
- const: "float16"
119123
title: |
120124
`IEEE 754 binary16

python/tensorstore/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,15 @@ class Indexable(metaclass=abc.ABCMeta):
257257
Data types
258258
"""
259259

260+
float4_e2m1fn: dtype
261+
"""8-bit floating-point data type.
262+
263+
Details in https://github.com/jax-ml/ml_dtypes#float4_e2m1fn
264+
265+
Group:
266+
Data types
267+
"""
268+
260269
float16: dtype
261270
""":wikipedia:`IEEE 754 binary16 <Half-precision_floating-point_format>` half-precision floating-point data type. Correspond to ``numpy.float16``.
262271

python/tensorstore/__init__.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ float8_e5m2: dtype
7676
"8-bit floating-point data type.\n\nDetails in https://github.com/jax-ml/ml_dtypes#float8_e5m2\n\nGroup:\n Data types\n"
7777
float8_e5m2fnuz: dtype
7878
"8-bit floating-point data type.\n\nDetails in https://github.com/jax-ml/ml_dtypes#float8_e5m2fnuz\n\nGroup:\n Data types\n"
79+
float4_e2m1fn: dtype
80+
"8-bit floating-point data type.\n\nDetails in https://github.com/jax-ml/ml_dtypes#float4_e2m1fn\n\nGroup:\n Data types\n"
7981
float16: dtype
8082
":wikipedia:`IEEE 754 binary16 <Half-precision_floating-point_format>` half-precision floating-point data type. Correspond to ``numpy.float16``.\n\nGroup:\n Data types\n"
8183
bfloat16: dtype
@@ -169,6 +171,7 @@ __all__ = [
169171
"experimental_update_verbose_logging",
170172
"float16",
171173
"float32",
174+
"float4_e2m1fn",
172175
"float64",
173176
"float8_e3m4",
174177
"float8_e4m3b11fnuz",

python/tensorstore/data_type.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class CustomDTypes {
8585
} \
8686
/**/
8787
TENSORSTORE_FOR_EACH_FLOAT8_DATA_TYPE(TENSORSTORE_INTERNAL_DO_ADD_TO_MAP)
88+
TENSORSTORE_FOR_EACH_MXFLOAT_DATA_TYPE(TENSORSTORE_INTERNAL_DO_ADD_TO_MAP)
8889
TENSORSTORE_FOR_EACH_LOW_PRECISION_INT_DATA_TYPE(
8990
TENSORSTORE_INTERNAL_DO_ADD_TO_MAP)
9091
#undef TENSORSTORE_INTERNAL_DO_ADD_TO_MAP
@@ -151,6 +152,8 @@ int GetNumpyTypeNum(DataType dtype) {
151152
TENSORSTORE_INTERNAL_DO_GET_NPY_TYPE_NUM_CASE(bfloat16_t)
152153
TENSORSTORE_FOR_EACH_FLOAT8_DATA_TYPE(
153154
TENSORSTORE_INTERNAL_DO_GET_NPY_TYPE_NUM_CASE)
155+
TENSORSTORE_FOR_EACH_MXFLOAT_DATA_TYPE(
156+
TENSORSTORE_INTERNAL_DO_GET_NPY_TYPE_NUM_CASE)
154157
TENSORSTORE_FOR_EACH_LOW_PRECISION_INT_DATA_TYPE(
155158
TENSORSTORE_INTERNAL_DO_GET_NPY_TYPE_NUM_CASE)
156159
#undef TENSORSTORE_INTERNAL_DO_GET_NPY_TYPE_NUM_CASE

python/tensorstore/tests/custom_dtypes_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
ml_dtypes.float8_e4m3b11fnuz.dtype,
2929
ml_dtypes.float8_e5m2.dtype,
3030
ml_dtypes.float8_e5m2fnuz.dtype,
31+
ml_dtypes.float4_e2m1fn.dtype,
3132
ml_dtypes.bfloat16.dtype,
3233
ts.int4.numpy_dtype,
3334
ts.float8_e3m4.numpy_dtype,
@@ -36,6 +37,7 @@
3637
ts.float8_e4m3b11fnuz.numpy_dtype,
3738
ts.float8_e5m2.numpy_dtype,
3839
ts.float8_e5m2fnuz.numpy_dtype,
40+
ts.float4_e2m1fn.numpy_dtype,
3941
ts.bfloat16.numpy_dtype,
4042
np.dtype("int4"),
4143
np.dtype("float8_e3m4"),
@@ -44,6 +46,7 @@
4446
np.dtype("float8_e4m3b11fnuz"),
4547
np.dtype("float8_e5m2"),
4648
np.dtype("float8_e5m2fnuz"),
49+
np.dtype("float4_e2m1fn"),
4750
np.dtype("bfloat16"),
4851
]
4952

tensorstore/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ licenses(["notice"])
77

88
exports_files(["LICENSE"])
99

10+
exports_files(["generate_data_type.py"])
11+
1012
tensorstore_cc_library(
1113
name = "array",
1214
srcs = [
@@ -443,6 +445,7 @@ tensorstore_cc_library(
443445
"//tensorstore/util:float8",
444446
"//tensorstore/util:int2",
445447
"//tensorstore/util:int4",
448+
"//tensorstore/util:mxfloat",
446449
"//tensorstore/util:result",
447450
"//tensorstore/util:str_cat",
448451
"//tensorstore/util:utf8_string",

tensorstore/data_type.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,9 @@ TENSORSTORE_INTERNAL_INHERITED_CONVERT( //
418418
::tensorstore::dtypes::complex64_t,
419419
::tensorstore::dtypes::float8_e5m2fnuz_t,
420420
internal_data_type::ComplexNumericConvertDataType)
421+
TENSORSTORE_INTERNAL_INHERITED_CONVERT( //
422+
::tensorstore::dtypes::complex64_t, ::tensorstore::dtypes::float4_e2m1fn_t,
423+
internal_data_type::ComplexNumericConvertDataType)
421424
TENSORSTORE_INTERNAL_INHERITED_CONVERT( //
422425
::tensorstore::dtypes::complex64_t, ::tensorstore::dtypes::float16_t,
423426
internal_data_type::ComplexNumericConvertDataType)
@@ -451,6 +454,9 @@ TENSORSTORE_INTERNAL_INHERITED_CONVERT( //
451454
::tensorstore::dtypes::complex128_t,
452455
::tensorstore::dtypes::float8_e5m2fnuz_t,
453456
internal_data_type::ComplexNumericConvertDataType)
457+
TENSORSTORE_INTERNAL_INHERITED_CONVERT( //
458+
::tensorstore::dtypes::complex128_t, ::tensorstore::dtypes::float4_e2m1fn_t,
459+
internal_data_type::ComplexNumericConvertDataType)
454460
TENSORSTORE_INTERNAL_INHERITED_CONVERT( //
455461
::tensorstore::dtypes::complex128_t, ::tensorstore::dtypes::float16_t,
456462
internal_data_type::ComplexNumericConvertDataType)

tensorstore/data_type.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
#include "tensorstore/util/float8.h"
8484
#include "tensorstore/util/int2.h"
8585
#include "tensorstore/util/int4.h"
86+
#include "tensorstore/util/mxfloat.h"
8687
#include "tensorstore/util/str_cat.h"
8788
#include "tensorstore/util/utf8_string.h"
8889

@@ -171,6 +172,9 @@ using float8_e5m2_t = ::tensorstore::Float8e5m2;
171172
using float8_e5m2fnuz_t = ::tensorstore::Float8e5m2fnuz;
172173
///
173174
/// \ingroup data types
175+
using float4_e2m1fn_t = ::tensorstore::Float4e2m1fn;
176+
///
177+
/// \ingroup data types
174178
using bfloat16_t = ::tensorstore::BFloat16;
175179

176180
/// :wikipedia:`IEEE 754 binary16<Half-precision_floating-point_format>`
@@ -243,6 +247,7 @@ enum class DataTypeId {
243247
float8_e4m3b11fnuz_t,
244248
float8_e5m2_t,
245249
float8_e5m2fnuz_t,
250+
float4_e2m1fn_t,
246251
float16_t,
247252
bfloat16_t,
248253
float32_t,
@@ -295,8 +300,13 @@ inline constexpr size_t kNumDataTypeIds =
295300
X(float8_e5m2fnuz_t, ##__VA_ARGS__) \
296301
/**/
297302

303+
#define TENSORSTORE_FOR_EACH_MXFLOAT_DATA_TYPE(X, ...) \
304+
X(float4_e2m1fn_t, ##__VA_ARGS__) \
305+
/**/
306+
298307
#define TENSORSTORE_FOR_EACH_LOW_PRECISION_FLOAT_DATA_TYPE(X, ...) \
299308
TENSORSTORE_FOR_EACH_FLOAT8_DATA_TYPE(X, ##__VA_ARGS__) \
309+
TENSORSTORE_FOR_EACH_MXFLOAT_DATA_TYPE(X, ##__VA_ARGS__) \
300310
X(float16_t, ##__VA_ARGS__) \
301311
X(bfloat16_t, ##__VA_ARGS__) \
302312
/**/

0 commit comments

Comments
 (0)