1+ // ===- NumericUtils.cpp - numeric utilities ---------------------*- C++ -*-===//
2+ //
3+ // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+ // See https://llvm.org/LICENSE.txt for license information.
5+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+ //
7+ // ===----------------------------------------------------------------------===//
8+ #include " gc/Transforms/Utils/NumericUtils.h"
9+
10+ namespace mlir {
11+ namespace gc {
12+
13+ const uint32_t kF32MantiBits = 23 ;
14+ const uint32_t kF32HalfMantiBitDiff = 13 ;
15+ const uint32_t kF32HalfBitDiff = 16 ;
16+ const Float32Bits kF32Magic = {113 << kF32MantiBits };
17+ const uint32_t kF32HalfExpAdjust = (127 - 15 ) << kF32MantiBits ;
18+ const uint32_t kF32BfMantiBitDiff = 16 ;
19+
20+ // / Constructs the 16 bit representation for a half precision value from a float
21+ // / value. This implementation is adapted from Eigen.
22+ uint16_t float2half (float floatValue) {
23+ const Float32Bits inf = {255 << kF32MantiBits };
24+ const Float32Bits f16max = {(127 + 16 ) << kF32MantiBits };
25+ const Float32Bits denormMagic = {((127 - 15 ) + (kF32MantiBits - 10 ) + 1 )
26+ << kF32MantiBits };
27+ uint32_t signMask = 0x80000000u ;
28+ uint16_t halfValue = static_cast <uint16_t >(0x0u );
29+ Float32Bits f;
30+ f.f = floatValue;
31+ uint32_t sign = f.u & signMask;
32+ f.u ^= sign;
33+
34+ if (f.u >= f16max.u ) {
35+ const uint32_t halfQnan = 0x7e00 ;
36+ const uint32_t halfInf = 0x7c00 ;
37+ // Inf or NaN (all exponent bits set).
38+ halfValue = (f.u > inf.u ) ? halfQnan : halfInf; // NaN->qNaN and Inf->Inf
39+ } else {
40+ // (De)normalized number or zero.
41+ if (f.u < kF32Magic .u ) {
42+ // The resulting FP16 is subnormal or zero.
43+ //
44+ // Use a magic value to align our 10 mantissa bits at the bottom of the
45+ // float. As long as FP addition is round-to-nearest-even this works.
46+ f.f += denormMagic.f ;
47+
48+ halfValue = static_cast <uint16_t >(f.u - denormMagic.u );
49+ } else {
50+ uint32_t mantOdd =
51+ (f.u >> kF32HalfMantiBitDiff ) & 1 ; // Resulting mantissa is odd.
52+
53+ // Update exponent, rounding bias part 1. The following expressions are
54+ // equivalent to `f.u += ((unsigned int)(15 - 127) << kF32MantiBits) +
55+ // 0xfff`, but without arithmetic overflow.
56+ f.u += 0xc8000fffU ;
57+ // Rounding bias part 2.
58+ f.u += mantOdd;
59+ halfValue = static_cast <uint16_t >(f.u >> kF32HalfMantiBitDiff );
60+ }
61+ }
62+
63+ halfValue |= static_cast <uint16_t >(sign >> kF32HalfBitDiff );
64+ return halfValue;
65+ }
66+
67+ // / Converts the 16 bit representation of a half precision value to a float
68+ // / value. This implementation is adapted from Eigen.
69+ float half2float (uint16_t halfValue) {
70+ const uint32_t shiftedExp =
71+ 0x7c00 << kF32HalfMantiBitDiff ; // Exponent mask after shift.
72+
73+ // Initialize the float representation with the exponent/mantissa bits.
74+ Float32Bits f = {
75+ static_cast <uint32_t >((halfValue & 0x7fff ) << kF32HalfMantiBitDiff )};
76+ const uint32_t exp = shiftedExp & f.u ;
77+ f.u += kF32HalfExpAdjust ; // Adjust the exponent
78+
79+ // Handle exponent special cases.
80+ if (exp == shiftedExp) {
81+ // Inf/NaN
82+ f.u += kF32HalfExpAdjust ;
83+ } else if (exp == 0 ) {
84+ // Zero/Denormal?
85+ f.u += 1 << kF32MantiBits ;
86+ f.f -= kF32Magic .f ;
87+ }
88+
89+ f.u |= (halfValue & 0x8000 ) << kF32HalfBitDiff ; // Sign bit.
90+ return f.f ;
91+ }
92+
93+ // Constructs the 16 bit representation for a bfloat value from a float value.
94+ // This implementation is adapted from Eigen.
95+ uint16_t float2bfloat (float floatValue) {
96+ if (std::isnan (floatValue))
97+ return std::signbit (floatValue) ? 0xFFC0 : 0x7FC0 ;
98+
99+ Float32Bits floatBits;
100+ floatBits.f = floatValue;
101+ uint16_t bfloatBits;
102+
103+ // Least significant bit of resulting bfloat.
104+ uint32_t lsb = (floatBits.u >> kF32BfMantiBitDiff ) & 1 ;
105+ uint32_t roundingBias = 0x7fff + lsb;
106+ floatBits.u += roundingBias;
107+ bfloatBits = static_cast <uint16_t >(floatBits.u >> kF32BfMantiBitDiff );
108+ return bfloatBits;
109+ }
110+
111+ // Converts the 16 bit representation of a bfloat value to a float value. This
112+ // implementation is adapted from Eigen.
113+ float bfloat2float (uint16_t bfloatBits) {
114+ Float32Bits floatBits;
115+ floatBits.u = static_cast <uint32_t >(bfloatBits) << kF32BfMantiBitDiff ;
116+ return floatBits.f ;
117+ }
118+
119+ std::variant<float , int64_t > numeric_limits_minimum (Type type) {
120+ Type t1 = getElementTypeOrSelf (type);
121+ if (t1.isF32 ()) {
122+ return -std::numeric_limits<float >::infinity ();
123+ } else if (t1.isBF16 ()) {
124+ return bfloat2float (float2bfloat (-std::numeric_limits<float >::infinity ()));
125+ } else if (t1.isF16 ()) {
126+ return (float )half2float (
127+ float2half (-std::numeric_limits<float >::infinity ()));
128+ } else if (t1.isSignedInteger (8 )) {
129+ return int64_t (-128 );
130+ } else if (t1.isSignedInteger (32 )) {
131+ return int64_t (std::numeric_limits<int32_t >::min ());
132+ } else if (t1.isSignlessInteger (8 ) or t1.isSignlessInteger (32 )) {
133+ return int64_t (0 );
134+ } else {
135+ llvm_unreachable (" unsupported data type" );
136+ return (int64_t )0 ;
137+ }
138+ }
139+
140+ std::variant<float , int64_t > numericLimitsMaximum (Type type) {
141+ Type t1 = getElementTypeOrSelf (type);
142+ if (t1.isF32 ()) {
143+ return std::numeric_limits<float >::infinity ();
144+ } else if (t1.isBF16 ()) {
145+ return bfloat2float (float2bfloat (std::numeric_limits<float >::infinity ()));
146+ } else if (t1.isF16 ()) {
147+ return (float )half2float (
148+ float2half (std::numeric_limits<float >::infinity ()));
149+ } else if (t1.isSignedInteger (8 )) {
150+ return int64_t (127 );
151+ } else if (t1.isSignedInteger (32 )) {
152+ return int64_t (std::numeric_limits<int32_t >::max ());
153+ } else if (t1.isSignlessInteger (8 )) {
154+ return int64_t (255 );
155+ } else if (t1.isSignedInteger (32 )) {
156+ return int64_t (std::numeric_limits<uint32_t >::max ());
157+ } else {
158+ llvm_unreachable (" unsupported data type" );
159+ return (int64_t )0 ;
160+ }
161+ }
162+
163+ } // namespace gc
164+ } // namespace mlir
0 commit comments