Skip to content

Commit 16a842a

Browse files
authored
Support fp4 type in ORT (microsoft#25767)
### Description

onnx/onnx#6318 and onnx/onnx#6283 added FP4 support to ONNX. This change introduces the FP4 type in ORT and adds type support to one relevant operator (`Cast`) as a proof-of-concept for integrating the type into ORT. More op support will be added on an as-needed basis. This change took inspiration from the following PRs: microsoft#14731, microsoft#22228, microsoft#20362.

Some notes:
1) Only the `tensor` type gets support for FP4 initially. Secondary types like `seq(tensor)`, `sparse_tensor`, and `optional` do not get support (so as not to introduce unnecessary bloat to the framework without a solid use-case).
2) Flatbuffer-related files receive no updates in this PR.

### Motivation and Context

Be able to run FP4 models with ORT.
1 parent ef60e38 commit 16a842a

39 files changed

+1851
-246
lines changed

.github/workflows/linux_minimal_build.yml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ jobs:
369369
--build_wheel
370370
--use_binskim_compliant_compile_flags
371371
--disable_ml_ops
372-
--disable_types sparsetensor float8 optional
372+
--disable_types sparsetensor float4 float8 optional
373373
--include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config
374374
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF
375375
@@ -385,7 +385,7 @@ jobs:
385385
--build_wheel
386386
--use_binskim_compliant_compile_flags
387387
--disable_ml_ops
388-
--disable_types sparsetensor float8 optional
388+
--disable_types sparsetensor float4 float8 optional
389389
--include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config
390390
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF
391391
@@ -402,7 +402,7 @@ jobs:
402402
--build_wheel
403403
--use_binskim_compliant_compile_flags
404404
--disable_ml_ops
405-
--disable_types sparsetensor float8 optional
405+
--disable_types sparsetensor float4 float8 optional
406406
--include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config
407407
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF
408408
@@ -459,7 +459,7 @@ jobs:
459459
--disable_ml_ops
460460
--skip_tests
461461
--enable_reduced_operator_type_support
462-
--disable_types sparsetensor optional float8
462+
--disable_types sparsetensor optional float4 float8
463463
--include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config
464464
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF
465465
@@ -478,7 +478,7 @@ jobs:
478478
--disable_ml_ops
479479
--skip_tests
480480
--enable_reduced_operator_type_support
481-
--disable_types sparsetensor optional float8
481+
--disable_types sparsetensor optional float4 float8
482482
--include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config
483483
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF
484484
@@ -540,7 +540,7 @@ jobs:
540540
--disable_ml_ops
541541
--skip_tests
542542
--enable_reduced_operator_type_support
543-
--disable_types sparsetensor optional float8
543+
--disable_types sparsetensor optional float4 float8
544544
--include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config
545545
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF
546546
@@ -559,7 +559,7 @@ jobs:
559559
--disable_ml_ops
560560
--skip_tests
561561
--enable_reduced_operator_type_support
562-
--disable_types sparsetensor optional float8
562+
--disable_types sparsetensor optional float4 float8
563563
--include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config
564564
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF
565565

cmake/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ option(onnxruntime_DISABLE_ML_OPS "Disable traditional ML ops" OFF)
159159
option(onnxruntime_DISABLE_SPARSE_TENSORS "Disable sparse tensors data types" OFF)
160160
option(onnxruntime_DISABLE_OPTIONAL_TYPE "Disable optional type" OFF)
161161
option(onnxruntime_DISABLE_FLOAT8_TYPES "Disable float 8 types" OFF)
162+
option(onnxruntime_DISABLE_FLOAT4_TYPES "Disable float 4 types" OFF)
162163
option(onnxruntime_MINIMAL_BUILD "Exclude as much as possible from the build. Support ORT format models. No support for ONNX format models." OFF)
163164
option(onnxruntime_CLIENT_PACKAGE_BUILD "Enables default settings that are more appropriate for client/on-device workloads." OFF)
164165
cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "NOT onnxruntime_ENABLE_PYTHON;NOT onnxruntime_USE_CUDA" OFF)
@@ -1035,6 +1036,10 @@ function(onnxruntime_set_compile_flags target_name)
10351036
target_compile_definitions(${target_name} PRIVATE DISABLE_FLOAT8_TYPES)
10361037
endif()
10371038

1039+
if (onnxruntime_DISABLE_FLOAT4_TYPES)
1040+
target_compile_definitions(${target_name} PRIVATE DISABLE_FLOAT4_TYPES)
1041+
endif()
1042+
10381043
if (onnxruntime_ENABLE_ATEN)
10391044
target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN)
10401045
endif()

docs/OperatorKernels.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -622,10 +622,12 @@ Do not modify directly.*
622622
|||14|**T** = tensor(double), tensor(float), tensor(float16)<br/> **U** = tensor(double), tensor(float), tensor(float16)|
623623
|||[9, 13]|**T** = tensor(double), tensor(float), tensor(float16)|
624624
|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
625-
|Cast|*in* input:**T1**<br> *out* output:**T2**|19+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
626-
|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
627-
|||[9, 12]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
628-
|||[6, 8]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
625+
|Cast|*in* input:**T1**<br> *out* output:**T2**|23+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
626+
|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
627+
|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
628+
|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
629+
|||[9, 12]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
630+
|||[6, 8]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
629631
|Ceil|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
630632
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
631633
|Clip|*in* input:**T**<br> *in* min:**T**<br> *in* max:**T**<br> *out* output:**T**<br><br>or<br><br>*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int64), tensor(int8), tensor(uint64), tensor(uint8)|

include/onnxruntime/core/framework/data_types.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "core/framework/float8.h"
1717
#include "core/framework/float16.h"
1818
#include "core/framework/int4.h"
19+
#include "core/framework/float4.h"
1920
#include "core/graph/onnx_protobuf.h"
2021
#include "core/framework/to_tensor_proto_element_type.h"
2122

@@ -209,6 +210,7 @@ class DataTypeImpl {
209210
static const std::vector<MLDataType>& AllTensorTypesIRv4();
210211
static const std::vector<MLDataType>& AllTensorTypesIRv9();
211212
static const std::vector<MLDataType>& AllTensorTypesIRv10();
213+
static const std::vector<MLDataType>& AllTensorTypesIRv11();
212214

213215
static const std::vector<MLDataType>& AllFixedSizeTensorTypes(); // up to IR4 (no float 8), deprecated
214216
static const std::vector<MLDataType>& AllFixedSizeTensorTypesIRv4();
@@ -287,6 +289,10 @@ struct IsTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, uint16_
287289
#if !defined(DISABLE_FLOAT8_TYPES)
288290
,
289291
Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
292+
#endif
293+
#if !defined(DISABLE_FLOAT4_TYPES)
294+
,
295+
Float4E2M1x2
290296
#endif
291297
> {
292298
};
@@ -302,6 +308,10 @@ struct IsSparseTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, u
302308
#if !defined(DISABLE_FLOAT8_TYPES)
303309
,
304310
Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
311+
#endif
312+
#if !defined(DISABLE_FLOAT4_TYPES)
313+
,
314+
Float4E2M1x2
305315
#endif
306316
> {
307317
};
@@ -921,7 +931,7 @@ class OpaqueType : public NonTensorType<T> {
921931
*
922932
* \details This class contains an integer constant that can be
923933
* used for input data type dispatching. This class also stores the number of subelements per size units.
924-
* Example: For int4, the size unit is 1 byte and the number of subelements is 2.
934+
* Example: For float4/int4, the size unit is 1 byte and the number of subelements is 2.
925935
*
926936
*/
927937
class PrimitiveDataTypeBase : public DataTypeImpl {
@@ -1101,6 +1111,7 @@ inline const PrimitiveDataTypeBase* DataTypeImpl::AsPrimitiveDataType() const {
11011111
// Registers a subbyte primitive.
11021112
// Examples:
11031113
// - Int4x2 stores 2 packed 4-bit elements in 1 byte: ORT_*_SUBBYTE_TYPE(Int4x2, 2)
1114+
// - Float4E2M1x2 stores 2 packed 4-bit elements in 1 byte: ORT_*_SUBBYTE_TYPE(Float4E2M1x2, 2)
11041115
// - [not supported] Int3x8 could store 8 packed 3-bit elements in 3 bytes: ORT_*_SUBBYTE_TYPE(Int3x8, 8)
11051116
#define ORT_REGISTER_PRIM_SUBBYTE_TYPE(TYPE, NUM_SUB_ELEMS) \
11061117
template <> \

0 commit comments

Comments (0)