Skip to content

Commit 0dfe783

Browse files
committed
Merge branch 'rocm7.1_internal_testing' into rel-1.23.1
2 parents d9b2048 + 8d330d8 commit 0dfe783

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

47 files changed

+2052
-369
lines changed

.github/workflows/linux_minimal_build.yml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ jobs:
325325
--build_wheel
326326
--use_binskim_compliant_compile_flags
327327
--disable_ml_ops
328-
--disable_types sparsetensor float8 optional
328+
--disable_types sparsetensor float4 float8 optional
329329
--include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config
330330
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF
331331
@@ -341,7 +341,7 @@ jobs:
341341
--build_wheel
342342
--use_binskim_compliant_compile_flags
343343
--disable_ml_ops
344-
--disable_types sparsetensor float8 optional
344+
--disable_types sparsetensor float4 float8 optional
345345
--include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config
346346
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF
347347
@@ -358,7 +358,7 @@ jobs:
358358
--build_wheel
359359
--use_binskim_compliant_compile_flags
360360
--disable_ml_ops
361-
--disable_types sparsetensor float8 optional
361+
--disable_types sparsetensor float4 float8 optional
362362
--include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config
363363
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF
364364
@@ -408,7 +408,7 @@ jobs:
408408
--disable_ml_ops
409409
--skip_tests
410410
--enable_reduced_operator_type_support
411-
--disable_types sparsetensor optional float8
411+
--disable_types sparsetensor optional float4 float8
412412
--include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config
413413
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF
414414
@@ -427,7 +427,7 @@ jobs:
427427
--disable_ml_ops
428428
--skip_tests
429429
--enable_reduced_operator_type_support
430-
--disable_types sparsetensor optional float8
430+
--disable_types sparsetensor optional float4 float8
431431
--include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config
432432
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF
433433
@@ -483,7 +483,7 @@ jobs:
483483
--disable_ml_ops
484484
--skip_tests
485485
--enable_reduced_operator_type_support
486-
--disable_types sparsetensor optional float8
486+
--disable_types sparsetensor optional float4 float8
487487
--include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config
488488
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF
489489
@@ -502,7 +502,7 @@ jobs:
502502
--disable_ml_ops
503503
--skip_tests
504504
--enable_reduced_operator_type_support
505-
--disable_types sparsetensor optional float8
505+
--disable_types sparsetensor optional float4 float8
506506
--include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config
507507
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF
508508

cmake/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ option(onnxruntime_DISABLE_ML_OPS "Disable traditional ML ops" OFF)
155155
option(onnxruntime_DISABLE_SPARSE_TENSORS "Disable sparse tensors data types" OFF)
156156
option(onnxruntime_DISABLE_OPTIONAL_TYPE "Disable optional type" OFF)
157157
option(onnxruntime_DISABLE_FLOAT8_TYPES "Disable float 8 types" OFF)
158+
option(onnxruntime_DISABLE_FLOAT4_TYPES "Disable float 4 types" OFF)
158159
option(onnxruntime_MINIMAL_BUILD "Exclude as much as possible from the build. Support ORT format models. No support for ONNX format models." OFF)
159160
option(onnxruntime_CLIENT_PACKAGE_BUILD "Enables default settings that are more appropriate for client/on-device workloads." OFF)
160161
cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "NOT onnxruntime_ENABLE_PYTHON;NOT onnxruntime_USE_CUDA" OFF)
@@ -1029,6 +1030,10 @@ function(onnxruntime_set_compile_flags target_name)
10291030
target_compile_definitions(${target_name} PRIVATE DISABLE_FLOAT8_TYPES)
10301031
endif()
10311032

1033+
if (onnxruntime_DISABLE_FLOAT4_TYPES)
1034+
target_compile_definitions(${target_name} PRIVATE DISABLE_FLOAT4_TYPES)
1035+
endif()
1036+
10321037
if (onnxruntime_ENABLE_ATEN)
10331038
target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN)
10341039
endif()

docs/OperatorKernels.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -625,10 +625,12 @@ Do not modify directly.*
625625
|||14|**T** = tensor(double), tensor(float), tensor(float16)<br/> **U** = tensor(double), tensor(float), tensor(float16)|
626626
|||[9, 13]|**T** = tensor(double), tensor(float), tensor(float16)|
627627
|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
628-
|Cast|*in* input:**T1**<br> *out* output:**T2**|19+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
629-
|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
630-
|||[9, 12]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
631-
|||[6, 8]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
628+
|Cast|*in* input:**T1**<br> *out* output:**T2**|23+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
629+
|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
630+
|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
631+
|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
632+
|||[9, 12]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
633+
|||[6, 8]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
632634
|Ceil|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
633635
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
634636
|Clip|*in* input:**T**<br> *in* min:**T**<br> *in* max:**T**<br> *out* output:**T**<br><br>or<br><br>*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int64), tensor(int8), tensor(uint64), tensor(uint8)|

include/onnxruntime/core/framework/data_types.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "core/framework/float8.h"
1717
#include "core/framework/float16.h"
1818
#include "core/framework/int4.h"
19+
#include "core/framework/float4.h"
1920
#include "core/graph/onnx_protobuf.h"
2021
#include "core/framework/to_tensor_proto_element_type.h"
2122

@@ -209,6 +210,7 @@ class DataTypeImpl {
209210
static const std::vector<MLDataType>& AllTensorTypesIRv4();
210211
static const std::vector<MLDataType>& AllTensorTypesIRv9();
211212
static const std::vector<MLDataType>& AllTensorTypesIRv10();
213+
static const std::vector<MLDataType>& AllTensorTypesIRv11();
212214

213215
static const std::vector<MLDataType>& AllFixedSizeTensorTypes(); // up to IR4 (no float 8), deprecated
214216
static const std::vector<MLDataType>& AllFixedSizeTensorTypesIRv4();
@@ -287,6 +289,10 @@ struct IsTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, uint16_
287289
#if !defined(DISABLE_FLOAT8_TYPES)
288290
,
289291
Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
292+
#endif
293+
#if !defined(DISABLE_FLOAT4_TYPES)
294+
,
295+
Float4E2M1x2
290296
#endif
291297
> {
292298
};
@@ -302,6 +308,10 @@ struct IsSparseTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, u
302308
#if !defined(DISABLE_FLOAT8_TYPES)
303309
,
304310
Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
311+
#endif
312+
#if !defined(DISABLE_FLOAT4_TYPES)
313+
,
314+
Float4E2M1x2
305315
#endif
306316
> {
307317
};
@@ -921,7 +931,7 @@ class OpaqueType : public NonTensorType<T> {
921931
*
922932
* \details This class contains an integer constant that can be
923933
* used for input data type dispatching. This class also stores the number of subelements per size units.
924-
* Example: For int4, the size unit is 1 byte and the number of subelements is 2.
934+
* Example: For float4/int4, the size unit is 1 byte and the number of subelements is 2.
925935
*
926936
*/
927937
class PrimitiveDataTypeBase : public DataTypeImpl {
@@ -1101,6 +1111,7 @@ inline const PrimitiveDataTypeBase* DataTypeImpl::AsPrimitiveDataType() const {
11011111
// Registers a subbyte primitive.
11021112
// Examples:
11031113
// - Int4x2 stores 2 packed 4-bit elements in 1 byte: ORT_*_SUBBYTE_TYPE(Int4x2, 2)
1114+
// - Float4E2M1x2 stores 2 packed 4-bit elements in 1 byte: ORT_*_SUBBYTE_TYPE(Float4E2M1x2, 2)
11041115
// - [not supported] Int3x8 could store 8 packed 3-bit elements in 3 bytes: ORT_*_SUBBYTE_TYPE(Int3x8, 8)
11051116
#define ORT_REGISTER_PRIM_SUBBYTE_TYPE(TYPE, NUM_SUB_ELEMS) \
11061117
template <> \

0 commit comments

Comments
 (0)