Skip to content

Commit 4783e0a

Browse files
authored
[Core] Fix debug node input output compilation after Fp4 support was enabled in ORT (microsoft#25940)
### Description As title ### Motivation and Context Follow-up fixes to microsoft#25767
1 parent 0047263 commit 4783e0a

File tree

2 files changed

+121
-106
lines changed

2 files changed

+121
-106
lines changed

onnxruntime/core/framework/print_tensor_statistics_utils.h

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -94,33 +94,36 @@ void PrintCommonStats(const T* data, size_t count, TensorStatisticsData& tensor_
9494
}
9595
}
9696

97-
#define DEF_PRINT_COMMON_STATS_INT4(INT4_TYPE) \
98-
template <> \
99-
inline void PrintCommonStats<INT4_TYPE>( \
100-
const INT4_TYPE* data, size_t count, TensorStatisticsData&) { \
101-
using UnpackedType = typename INT4_TYPE::UnpackedType; \
102-
UnpackedType min = data[0].GetElem(0); \
103-
UnpackedType max = min; \
104-
for (size_t i = 1; i < count; i++) { \
105-
auto indices = INT4_TYPE::GetTensorElemIndices(i); \
106-
auto value = data[indices.first].GetElem(indices.second); \
107-
if (value > max) { \
108-
max = value; \
109-
} \
110-
if (value < min) { \
111-
min = value; \
112-
} \
113-
} \
114-
\
115-
std::cout << "Min="; \
116-
PrintValue(min); \
117-
\
118-
std::cout << ",Max="; \
119-
PrintValue(max); \
97+
#define DEF_PRINT_COMMON_STATS_4BIT(FOUR_BIT_TYPE) \
98+
template <> \
99+
inline void PrintCommonStats<FOUR_BIT_TYPE>( \
100+
const FOUR_BIT_TYPE* data, size_t count, TensorStatisticsData&) { \
101+
using UnpackedType = typename FOUR_BIT_TYPE::UnpackedType; \
102+
UnpackedType min = data[0].GetElem(0); \
103+
UnpackedType max = min; \
104+
for (size_t i = 1; i < count; i++) { \
105+
auto indices = FOUR_BIT_TYPE::GetTensorElemIndices(i); \
106+
auto value = data[indices.first].GetElem(indices.second); \
107+
if (value > max) { \
108+
max = value; \
109+
} \
110+
if (value < min) { \
111+
min = value; \
112+
} \
113+
} \
114+
\
115+
std::cout << "Min="; \
116+
PrintValue(min); \
117+
\
118+
std::cout << ",Max="; \
119+
PrintValue(max); \
120120
}
121121

122-
DEF_PRINT_COMMON_STATS_INT4(Int4x2)
123-
DEF_PRINT_COMMON_STATS_INT4(UInt4x2)
122+
DEF_PRINT_COMMON_STATS_4BIT(Int4x2)
123+
DEF_PRINT_COMMON_STATS_4BIT(UInt4x2)
124+
#if !defined(DISABLE_FLOAT4_TYPES)
125+
DEF_PRINT_COMMON_STATS_4BIT(Float4E2M1x2)
126+
#endif
124127

125128
template <typename T>
126129
void PrintHalfStats(const T* data, size_t count) {

0 commit comments

Comments
 (0)