Skip to content

Commit 09e0727

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
add up to Uint64 (#6825)
Summary: Add support up to uint64_t. People are wanting uint16 for quant purposes. Differential Revision: D65846964
1 parent ee32ea3 commit 09e0727

File tree

3 files changed

+119
-25
lines changed

3 files changed

+119
-25
lines changed

runtime/core/exec_aten/util/scalar_type_util.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,31 @@ struct is_qint_type
503503
: std::integral_constant<bool, isQIntType(CppTypeToScalarType<T>::value)> {
504504
};
505505

506+
constexpr bool isFloat8Type(::executorch::aten::ScalarType t) {
507+
// Don't forget to extend this when adding new QInt types
508+
return t == ::executorch::aten::ScalarType::Float8_e5m2 ||
509+
t == ::executorch::aten::ScalarType::Float8_e4m3fn ||
510+
t == ::executorch::aten::ScalarType::Float8_e5m2fnuz ||
511+
t == ::executorch::aten::ScalarType::Float8_e4m3fnuz;
512+
}
513+
514+
template <typename T>
515+
struct is_float8_type
516+
: std::integral_constant<bool, isFloat8Type(CppTypeToScalarType<T>::value)> {
517+
};
518+
519+
constexpr bool isBarebonesUnsignedType(::executorch::aten::ScalarType t) {
520+
// Don't forget to extend this when adding new QInt types
521+
return t == ::executorch::aten::ScalarType::UInt16 ||
522+
t == ::executorch::aten::ScalarType::UInt32 ||
523+
t == ::executorch::aten::ScalarType::UInt64;
524+
}
525+
526+
template <typename T>
527+
struct is_barebones_unsigned_type
528+
: std::integral_constant<bool, isBarebonesUnsignedType(CppTypeToScalarType<T>::value)> {
529+
};
530+
506531
inline ::executorch::aten::ScalarType toQIntType(
507532
::executorch::aten::ScalarType t) {
508533
switch (t) {
@@ -883,6 +908,14 @@ struct promote_types {
883908
std::is_same<T1, T2>::value ||
884909
(!is_bits_type<T1>::value && !is_bits_type<T2>::value),
885910
"promote_types not valid for bits dtypes");
911+
static_assert(
912+
std::is_same<T1, T2>::value ||
913+
(!is_float8_type<T1>::value && !is_float8_type<T2>::value),
914+
"promote_types not valid for float8 dtypes");
915+
static_assert(
916+
std::is_same<T1, T2>::value ||
917+
(!is_barebones_unsigned_type<T1>::value && !is_barebones_unsigned_type<T2>::value),
918+
"promote_types not valid for barebones unsigned dtypes");
886919

887920
using promoted_type_not_respecting_half_to_float =
888921
typename internal::promote_types_lookup<T1, T2>::type;
@@ -945,6 +978,24 @@ inline ::executorch::aten::ScalarType promoteTypes(
945978
ET_CHECK_MSG(false, "promoteTypes not valid for bits dtypes");
946979
}
947980

981+
// For Float8 types, only allow exact match
982+
if (::executorch::runtime::isFloat8Type(a) && a == b) {
983+
return a;
984+
}
985+
if (::executorch::runtime::isFloat8Type(a) ||
986+
::executorch::runtime::isFloat8Type(b)) {
987+
ET_CHECK_MSG(false, "promoteTypes not valid for float8 dtypes");
988+
}
989+
990+
// For barebones uint types, only allow exact match
991+
if (::executorch::runtime::isBarebonesUnsignedType(a) && a == b) {
992+
return a;
993+
}
994+
if (::executorch::runtime::isBarebonesUnsignedType(a) ||
995+
::executorch::runtime::isBarebonesUnsignedType(b)) {
996+
ET_CHECK_MSG(false, "promoteTypes not valid for barebone unsigned dtypes");
997+
}
998+
948999
// 12 types are handled by this function, see the constexpr definitions above
9491000
const int NUM_PROMOTE_TYPES = 13;
9501001

@@ -1437,6 +1488,8 @@ using ::executorch::runtime::is_bits_type;
14371488
using ::executorch::runtime::is_complex_type;
14381489
using ::executorch::runtime::is_integral_type;
14391490
using ::executorch::runtime::is_qint_type;
1491+
using ::executorch::runtime::is_float8_type;
1492+
using ::executorch::runtime::is_barebones_unsigned_type;
14401493
using ::executorch::runtime::isBitsType;
14411494
using ::executorch::runtime::isComplexType;
14421495
using ::executorch::runtime::isFloatingType;

runtime/core/exec_aten/util/test/scalar_type_util_test.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,11 @@ struct promote_types_is_valid
170170
(!executorch::runtime::is_qint_type<T1>::value &&
171171
!executorch::runtime::is_qint_type<T2>::value &&
172172
!executorch::runtime::is_bits_type<T1>::value &&
173-
!executorch::runtime::is_bits_type<T2>::value))> {};
173+
!executorch::runtime::is_bits_type<T2>::value &&
174+
!executorch::runtime::is_float8_type<T1>::value &&
175+
!executorch::runtime::is_float8_type<T2>::value &&
176+
!executorch::runtime::is_barebones_unsigned_type<T1>::value &&
177+
!executorch::runtime::is_barebones_unsigned_type<T2>::value))> {};
174178

175179
template <typename T1, bool half_to_float>
176180
struct CompileTimePromoteTypesTestCase {

runtime/core/portable_type/scalar_type.h

Lines changed: 61 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,36 @@ namespace executorch {
4747
namespace runtime {
4848
namespace etensor {
4949

50+
// Placing a bunch of unused dtypes here as our macros don't make it easy
51+
// to skip scalar types defined in aten that we dont have.
52+
namespace unused_dtype {
53+
struct alignas(1) Float8_e5m2 {
54+
uint8_t x;
55+
using underlying = uint8_t;
56+
Float8_e5m2() = default;
57+
explicit Float8_e5m2(uint8_t val) : x(val) {}
58+
};
59+
struct alignas(1) Float8_e4m3fn {
60+
uint8_t x;
61+
using underlying = uint8_t;
62+
Float8_e4m3fn() = default;
63+
explicit Float8_e4m3fn(uint8_t val) : x(val) {}
64+
};
65+
struct alignas(1) Float8_e5m2fnuz {
66+
uint8_t x;
67+
using underlying = uint8_t;
68+
Float8_e5m2fnuz() = default;
69+
explicit Float8_e5m2fnuz(uint8_t val) : x(val) {}
70+
};
71+
struct alignas(1) Float8_e4m3fnuz {
72+
uint8_t x;
73+
using underlying = uint8_t;
74+
Float8_e4m3fnuz() = default;
75+
explicit Float8_e4m3fnuz(uint8_t val) : x(val) {}
76+
};
77+
78+
} // namespace unused_dtype
79+
5080
/**
5181
* Calls the provided macro on every ScalarType, providing the C type and the
5282
* ScalarType name to each call.
@@ -59,30 +89,37 @@ namespace etensor {
5989
* @param _ A macro that takes two parameters: the name of a C type, and the
6090
* name of the corresponding ScalarType enumerator.
6191
*/
62-
#define ET_FORALL_SCALAR_TYPES(_) \
63-
_(uint8_t, Byte) /* 0 */ \
64-
_(int8_t, Char) /* 1 */ \
65-
_(int16_t, Short) /* 2 */ \
66-
_(int32_t, Int) /* 3 */ \
67-
_(int64_t, Long) /* 4 */ \
68-
_(::torch::executor::Half, Half) /* 5 */ \
69-
_(float, Float) /* 6 */ \
70-
_(double, Double) /* 7 */ \
71-
_(::torch::executor::complex<::torch::executor::Half>, ComplexHalf) /* 8 */ \
72-
_(::torch::executor::complex<float>, ComplexFloat) /* 9 */ \
73-
_(::torch::executor::complex<double>, ComplexDouble) /* 10 */ \
74-
_(bool, Bool) /* 11 */ \
75-
_(::torch::executor::qint8, QInt8) /* 12 */ \
76-
_(::torch::executor::quint8, QUInt8) /* 13 */ \
77-
_(::torch::executor::qint32, QInt32) /* 14 */ \
78-
_(::torch::executor::BFloat16, BFloat16) /* 15 */ \
79-
_(::torch::executor::quint4x2, QUInt4x2) /* 16 */ \
80-
_(::torch::executor::quint2x4, QUInt2x4) /* 17 */ \
81-
_(::torch::executor::bits1x8, Bits1x8) /* 18 */ \
82-
_(::torch::executor::bits2x4, Bits2x4) /* 19 */ \
83-
_(::torch::executor::bits4x2, Bits4x2) /* 20 */ \
84-
_(::torch::executor::bits8, Bits8) /* 21 */ \
85-
_(::torch::executor::bits16, Bits16) /* 22 */
92+
#define ET_FORALL_SCALAR_TYPES(_) \
93+
_(uint8_t, Byte) /* 0 */ \
94+
_(int8_t, Char) /* 1 */ \
95+
_(int16_t, Short) /* 2 */ \
96+
_(int32_t, Int) /* 3 */ \
97+
_(int64_t, Long) /* 4 */ \
98+
_(::executorch::runtime::etensor::Half, Half) /* 5 */ \
99+
_(float, Float) /* 6 */ \
100+
_(double, Double) /* 7 */ \
101+
_(::executorch::runtime::etensor::complex<::torch::executor::Half>, ComplexHalf) /* 8 */ \
102+
_(::executorch::runtime::etensor::complex<float>, ComplexFloat) /* 9 */ \
103+
_(::executorch::runtime::etensor::complex<double>, ComplexDouble) /* 10 */ \
104+
_(bool, Bool) /* 11 */ \
105+
_(::executorch::runtime::etensor::qint8, QInt8) /* 12 */ \
106+
_(::executorch::runtime::etensor::quint8, QUInt8) /* 13 */ \
107+
_(::executorch::runtime::etensor::qint32, QInt32) /* 14 */ \
108+
_(::executorch::runtime::etensor::BFloat16, BFloat16) /* 15 */ \
109+
_(::executorch::runtime::etensor::quint4x2, QUInt4x2) /* 16 */ \
110+
_(::executorch::runtime::etensor::quint2x4, QUInt2x4) /* 17 */ \
111+
_(::executorch::runtime::etensor::bits1x8, Bits1x8) /* 18 */ \
112+
_(::executorch::runtime::etensor::bits2x4, Bits2x4) /* 19 */ \
113+
_(::executorch::runtime::etensor::bits4x2, Bits4x2) /* 20 */ \
114+
_(::executorch::runtime::etensor::bits8, Bits8) /* 21 */ \
115+
_(::executorch::runtime::etensor::bits16, Bits16) /* 22 */ \
116+
_(::executorch::runtime::etensor::unused_dtype::Float8_e5m2, Float8_e5m2) /* 23 */ \
117+
_(::executorch::runtime::etensor::unused_dtype::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
118+
_(::executorch::runtime::etensor::unused_dtype::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
119+
_(::executorch::runtime::etensor::unused_dtype::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \
120+
_(uint16_t, UInt16) /* 27 */ \
121+
_(uint32_t, UInt32) /* 28 */ \
122+
_(uint64_t, UInt64) /* 29 */
86123

87124
/**
88125
* Data types (dtypes) that can be used as element types in ETensors.

0 commit comments

Comments
 (0)