@@ -94,33 +94,36 @@ void PrintCommonStats(const T* data, size_t count, TensorStatisticsData& tensor_
9494 }
9595}
9696
97- #define DEF_PRINT_COMMON_STATS_INT4 ( INT4_TYPE ) \
98- template <> \
99- inline void PrintCommonStats<INT4_TYPE >( \
100- const INT4_TYPE * data, size_t count, TensorStatisticsData&) { \
101- using UnpackedType = typename INT4_TYPE ::UnpackedType; \
102- UnpackedType min = data[0 ].GetElem (0 ); \
103- UnpackedType max = min; \
104- for (size_t i = 1 ; i < count; i++) { \
105- auto indices = INT4_TYPE ::GetTensorElemIndices (i); \
106- auto value = data[indices.first ].GetElem (indices.second ); \
107- if (value > max) { \
108- max = value; \
109- } \
110- if (value < min) { \
111- min = value; \
112- } \
113- } \
114- \
115- std::cout << " Min=" ; \
116- PrintValue (min); \
117- \
118- std::cout << " ,Max=" ; \
119- PrintValue (max); \
// Defines a PrintCommonStats specialization for a packed 4-bit tensor type.
//
// FOUR_BIT_TYPE must provide:
//   - a nested `UnpackedType` (the type of a single unpacked element),
//   - a static `GetTensorElemIndices(i)` whose result exposes `.first` (index
//     into `data`) and `.second` (argument passed to `GetElem`),
//   - a `GetElem(index)` member returning one element.
//
// The generated function walks element indices [0, count), tracks the smallest
// and largest unpacked values, and prints them through PrintValue. The
// TensorStatisticsData argument is accepted to match the primary template's
// signature but is not used. When there is nothing to scan (null buffer or
// count == 0) the function prints nothing instead of reading data[0].
#define DEF_PRINT_COMMON_STATS_4BIT(FOUR_BIT_TYPE)                          \
  template <>                                                               \
  inline void PrintCommonStats<FOUR_BIT_TYPE>(                              \
      const FOUR_BIT_TYPE* data, size_t count, TensorStatisticsData&) {     \
    /* Guard: seeding min/max from data[0] is only valid when non-empty. */ \
    if (data == nullptr || count == 0) {                                    \
      return;                                                               \
    }                                                                       \
    using UnpackedType = typename FOUR_BIT_TYPE::UnpackedType;              \
    UnpackedType min = data[0].GetElem(0);                                  \
    UnpackedType max = min;                                                 \
    /* Element 0 seeded min/max above, so the scan starts at 1. */          \
    for (size_t i = 1; i < count; i++) {                                    \
      const auto indices = FOUR_BIT_TYPE::GetTensorElemIndices(i);          \
      auto value = data[indices.first].GetElem(indices.second);             \
      if (value > max) {                                                    \
        max = value;                                                        \
      }                                                                     \
      if (value < min) {                                                    \
        min = value;                                                        \
      }                                                                     \
    }                                                                       \
                                                                            \
    std::cout << " Min=";                                                   \
    PrintValue(min);                                                        \
                                                                            \
    std::cout << " ,Max=";                                                  \
    PrintValue(max);                                                        \
  }
121121
122- DEF_PRINT_COMMON_STATS_INT4 (Int4x2)
123- DEF_PRINT_COMMON_STATS_INT4 (UInt4x2)
122+ DEF_PRINT_COMMON_STATS_4BIT (Int4x2)
123+ DEF_PRINT_COMMON_STATS_4BIT (UInt4x2)
124+ #if !defined(DISABLE_FLOAT4_TYPES)
125+ DEF_PRINT_COMMON_STATS_4BIT (Float4E2M1x2)
126+ #endif
124127
125128template <typename T>
126129void PrintHalfStats (const T* data, size_t count) {
0 commit comments