Skip to content

Commit 270204a

Browse files
authored
Merge pull request #339 from cppalliance/fma
2 parents 5037e2c + 4c200e4 commit 270204a

File tree

4 files changed

+139
-14
lines changed

4 files changed

+139
-14
lines changed

include/boost/decimal/cmath.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
#define BOOST_DECIMAL_DEC_NAN std::numeric_limits<boost::decimal::decimal64>::signaling_NaN()
6262
#define BOOST_DECIMAL_FP_FAST_FMAD32 1
6363
#define BOOST_DECIMAL_FP_FAST_FMAD64 1
64+
#define BOOST_DECIMAL_FP_FAST_FMAD128 1
6465

6566
namespace boost { namespace decimal {
6667

@@ -119,6 +120,11 @@ constexpr auto fma(decimal64 x, decimal64 y, decimal64 z) noexcept -> decimal64
119120
return fmad64(x, y, z);
120121
}
121122

123+
constexpr auto fma(decimal128 x, decimal128 y, decimal128 z) noexcept -> decimal128
124+
{
125+
return fmad128(x, y, z);
126+
}
127+
122128
constexpr auto samequantum(decimal32 lhs, decimal32 rhs) noexcept -> bool
123129
{
124130
return samequantumd32(lhs, rhs);

include/boost/decimal/decimal128.hpp

Lines changed: 98 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,7 @@ class decimal128 final
535535
friend constexpr auto copysignd128(decimal128 mag, decimal128 sgn) noexcept -> decimal128;
536536
friend constexpr auto scalblnd128(decimal128 num, long exp) noexcept -> decimal128;
537537
friend constexpr auto scalbnd128(decimal128 num, int exp) noexcept -> decimal128;
538+
friend constexpr auto fmad128(decimal128 x, decimal128 y, decimal128 z) noexcept -> decimal128;
538539
};
539540

540541
#if !defined(BOOST_DECIMAL_DISABLE_IOSTREAM)
@@ -1802,11 +1803,21 @@ constexpr auto operator*(decimal128 lhs, decimal128 rhs) noexcept -> decimal128
18021803

18031804
auto lhs_sig {lhs.full_significand()};
18041805
auto lhs_exp {lhs.biased_exponent()};
1805-
detail::normalize<decimal128>(lhs_sig, lhs_exp);
1806+
1807+
while (lhs_sig % 10 == 0 && lhs_sig != 0)
1808+
{
1809+
lhs_sig /= 10;
1810+
++lhs_exp;
1811+
}
18061812

18071813
auto rhs_sig {rhs.full_significand()};
18081814
auto rhs_exp {rhs.biased_exponent()};
1809-
detail::normalize<decimal128>(rhs_sig, rhs_exp);
1815+
1816+
while (rhs_sig % 10 == 0 && rhs_sig != 0)
1817+
{
1818+
rhs_sig /= 10;
1819+
++rhs_exp;
1820+
}
18101821

18111822
const auto result {d128_mul_impl(lhs_sig, lhs_exp, lhs.isneg(),
18121823
rhs_sig, rhs_exp, rhs.isneg())};
@@ -1825,12 +1836,20 @@ constexpr auto operator*(decimal128 lhs, Integer rhs) noexcept
18251836

18261837
auto lhs_sig {lhs.full_significand()};
18271838
auto lhs_exp {lhs.biased_exponent()};
1828-
detail::normalize<decimal128>(lhs_sig, lhs_exp);
1839+
while (lhs_sig % 10 == 0 && lhs_sig != 0)
1840+
{
1841+
lhs_sig /= 10;
1842+
++lhs_exp;
1843+
}
18291844
auto lhs_components {detail::decimal128_components{lhs_sig, lhs_exp, lhs.isneg()}};
18301845

18311846
auto rhs_sig {static_cast<detail::uint128>(detail::make_positive_unsigned(rhs))};
18321847
std::int32_t rhs_exp {0};
1833-
detail::normalize<decimal128>(rhs_sig, rhs_exp);
1848+
while (rhs_sig % 10 == 0 && rhs_sig != 0)
1849+
{
1850+
rhs_sig /= 10;
1851+
++rhs_exp;
1852+
}
18341853
auto unsigned_sig_rhs {detail::make_positive_unsigned(rhs_sig)};
18351854
auto rhs_components {detail::decimal128_components{unsigned_sig_rhs, rhs_exp, (rhs < 0)}};
18361855

@@ -2260,6 +2279,81 @@ constexpr auto scalbnd128(decimal128 num, int expval) noexcept -> decimal128
22602279
return scalblnd128(num, static_cast<long>(expval));
22612280
}
22622281

2282+
constexpr auto fmad128(decimal128 x, decimal128 y, decimal128 z) noexcept -> decimal128
2283+
{
2284+
// First calculate x * y without rounding
2285+
constexpr decimal128 zero {0, 0};
2286+
2287+
const auto res {detail::check_non_finite(x, y)};
2288+
if (res != zero)
2289+
{
2290+
return res;
2291+
}
2292+
2293+
auto sig_lhs {x.full_significand()};
2294+
auto exp_lhs {x.biased_exponent()};
2295+
2296+
while (sig_lhs % 10 == 0 && sig_lhs != 0)
2297+
{
2298+
sig_lhs /= 10;
2299+
++exp_lhs;
2300+
}
2301+
2302+
auto sig_rhs {y.full_significand()};
2303+
auto exp_rhs {y.biased_exponent()};
2304+
2305+
while (sig_rhs % 10 == 0 && sig_rhs != 0)
2306+
{
2307+
sig_rhs /= 10;
2308+
++exp_rhs;
2309+
}
2310+
2311+
auto mul_result {d128_mul_impl(sig_lhs, exp_lhs, x.isneg(), sig_rhs, exp_rhs, y.isneg())};
2312+
const decimal128 dec_result {mul_result.sig, mul_result.exp, mul_result.sign};
2313+
2314+
const auto res_add {detail::check_non_finite(dec_result, z)};
2315+
if (res_add != zero)
2316+
{
2317+
return res_add;
2318+
}
2319+
2320+
bool lhs_bigger {dec_result > z};
2321+
if (dec_result.isneg() && z.isneg())
2322+
{
2323+
lhs_bigger = !lhs_bigger;
2324+
}
2325+
bool abs_lhs_bigger {abs(dec_result) > abs(z)};
2326+
2327+
detail::normalize<decimal128>(mul_result.sig, mul_result.exp);
2328+
2329+
auto sig_z {z.full_significand()};
2330+
auto exp_z {z.biased_exponent()};
2331+
detail::normalize<decimal128>(sig_z, exp_z);
2332+
detail::decimal128_components z_components {sig_z, exp_z, z.isneg()};
2333+
2334+
if (!lhs_bigger)
2335+
{
2336+
detail::swap(mul_result, z_components);
2337+
abs_lhs_bigger = !abs_lhs_bigger;
2338+
}
2339+
2340+
detail::decimal128_components result {};
2341+
2342+
if (!mul_result.sign && z_components.sign)
2343+
{
2344+
result = d128_sub_impl(mul_result.sig, mul_result.exp, mul_result.sign,
2345+
z_components.sig, z_components.exp, z_components.sign,
2346+
abs_lhs_bigger);
2347+
}
2348+
else
2349+
{
2350+
result = d128_add_impl(mul_result.sig, mul_result.exp, mul_result.sign,
2351+
z_components.sig, z_components.exp, z_components.sign);
2352+
}
2353+
2354+
return {result.sig, result.exp, result.sign};
2355+
}
2356+
22632357
} //namespace decimal
22642358
} //namespace boost
22652359

include/boost/decimal/detail/emulated256.hpp

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -420,17 +420,35 @@ constexpr uint256_t operator%(uint256_t lhs, std::uint64_t rhs) noexcept
420420
}
421421

422422
// Get the 256-bit result of multiplication of two 128-bit unsigned integers
423-
constexpr uint256_t umul256_impl(std::uint64_t a, std::uint64_t b, std::uint64_t c, std::uint64_t d) noexcept
423+
constexpr uint256_t umul256_impl(std::uint64_t a_high, std::uint64_t a_low, std::uint64_t b_high, std::uint64_t b_low) noexcept
424424
{
425-
const auto ac = umul128(a, c);
426-
const auto bc = umul128(b, c);
427-
const auto ad = umul128(a, d);
428-
const auto bd = umul128(b, d);
425+
const auto low_product {static_cast<uint128>(a_low) * b_low};
426+
const auto mid_product1 {static_cast<uint128>(a_low) * b_high};
427+
const auto mid_product2 {static_cast<uint128>(a_high) * b_low};
428+
const auto high_product {static_cast<uint128>(a_high) * b_high};
429429

430-
const auto intermediate = (bd >> 64) + static_cast<std::uint64_t>(ad) + static_cast<std::uint64_t>(bc);
430+
uint128 carry {};
431431

432-
return {ac + (intermediate >> 64) + (ad >> 64) + (bc >> 64),
433-
(intermediate << 64) + static_cast<std::uint64_t>(bd)};
432+
const auto mid_combined {mid_product1 + mid_product2};
433+
if (mid_combined < mid_product1)
434+
{
435+
carry = 1;
436+
}
437+
438+
const auto mid_combined_high {mid_combined >> 64};
439+
const auto mid_combined_low {mid_combined << 64};
440+
441+
const auto low_sum {low_product + mid_combined_low};
442+
if (low_sum < low_product)
443+
{
444+
carry += 1;
445+
}
446+
447+
uint256_t result {};
448+
result.low = low_sum;
449+
result.high = high_product + mid_combined_high + carry;
450+
451+
return result;
434452
}
435453

436454
template<typename T>

test/test_cmath.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,14 @@ void test_copysign()
333333
template <typename Dec>
334334
void test_fma()
335335
{
336+
if (!BOOST_TEST_EQ(Dec(1, -1) * Dec(1, 1), Dec(1, 0)))
337+
{
338+
std::cerr << std::setprecision(std::numeric_limits<Dec>::digits10)
339+
<< " Mul: " << Dec(1, -1) * Dec(1, 1)
340+
<< "\nActual: " << Dec(1, 0) << std::endl;
341+
}
342+
343+
BOOST_TEST_EQ(Dec(1, 0) + Dec(1, 0, true), Dec(0, 0));
336344
BOOST_TEST_EQ(fma(Dec(1, -1), Dec(1, 1), Dec(1, 0, true)), Dec(0, 0));
337345

338346
std::uniform_real_distribution<double> dist(-1e10, 1e10);
@@ -1382,15 +1390,14 @@ int main()
13821390
test_copysign<decimal32>();
13831391
test_copysign<decimal64>();
13841392

1385-
#if (defined(__clang__) || defined(_MSC_VER) || !defined(__GNUC__) || (defined(__GNUC__) && __GNUC__ > 6))
13861393
test_fma<decimal32>();
13871394
test_fma<decimal64>();
1395+
test_fma<decimal128>();
13881396

13891397
test_sin<decimal32>();
13901398
test_cos<decimal32>();
13911399
test_sin<decimal64>();
13921400
test_cos<decimal64>();
1393-
#endif
13941401

13951402
test_modf<decimal32>();
13961403
test_modf<decimal64>();

0 commit comments

Comments
 (0)