Skip to content

Commit 77c4d0f

Browse files
authored
fix the issue that decimal divide not round. (#6471)
close #4488, close #6393, close #6462
1 parent 2480a1d commit 77c4d0f

File tree

4 files changed

+314
-7
lines changed

4 files changed

+314
-7
lines changed

dbms/src/Functions/divide.cpp

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,36 @@ struct TiDBDivideFloatingImpl<A, B, false>
6060
using ResultType = typename NumberTraits::ResultOfFloatingPointDivision<A, B>::Type;
6161

6262
template <typename Result = ResultType>
63-
static Result apply(A a, B b)
63+
static Result apply(A x, B d)
6464
{
65-
return static_cast<Result>(a) / b;
65+
/// ref https://github.com/pingcap/tiflash/issues/6462
66+
/// For division of Decimal/Decimal or Int/Decimal or Decimal/Int, we should round the result to make it compatible with TiDB.
67+
/// basically refer to https://stackoverflow.com/a/71634489
68+
if constexpr (std::is_integral_v<Result> || std::is_same_v<Result, Int256>)
69+
{
70+
/// 1. do the division first to get the quotient and mod. TODO(perf): find a unified `divmod` function to speed this up.
71+
Result quotient = x / d;
72+
Result mod = x % d;
73+
/// 2. get half of the divisor, which is the threshold used to decide whether to round up or down.
74+
/// note: don't use bit operations directly here; they may cause unexpected results.
75+
Result half = (d / 2) + (d % 2);
76+
77+
/// 3. compare the absolute values of mod and half; if |mod| >= |half|, then round up.
78+
Result abs_m = mod < 0 ? -mod : mod;
79+
Result abs_h = half < 0 ? -half : half;
80+
if (abs_m >= abs_h)
81+
{
82+
/// 4. now we need to round up, i.e., add 1 to the quotient's absolute value.
83+
/// if the signs of dividend and divisor are the same, then the quotient should be positive, otherwise negative.
84+
if ((x < 0) == (d < 0)) // same_sign, i.e., quotient >= 0
85+
quotient = quotient + 1;
86+
else
87+
quotient = quotient - 1;
88+
}
89+
return quotient;
90+
}
91+
else
92+
return static_cast<Result>(x) / d;
6693
}
6794
template <typename Result = ResultType>
6895
static Result apply(A a, B b, UInt8 & res_null)
@@ -75,7 +102,7 @@ struct TiDBDivideFloatingImpl<A, B, false>
75102
res_null = 1;
76103
return static_cast<Result>(0);
77104
}
78-
return static_cast<Result>(a) / b;
105+
return apply<Result>(a, b);
79106
}
80107
};
81108

@@ -102,7 +129,7 @@ struct TiDBDivideFloatingImpl<A, B, true>
102129
res_null = 1;
103130
return static_cast<Result>(0);
104131
}
105-
return static_cast<Result>(a) / static_cast<Result>(b);
132+
return apply<Result>(a, b);
106133
}
107134
};
108135

@@ -332,4 +359,4 @@ void registerFunctionDivideIntegralOrZero(FunctionFactory & factory)
332359
factory.registerFunction<FunctionDivideIntegralOrZero>();
333360
}
334361

335-
} // namespace DB
362+
} // namespace DB

dbms/src/Functions/tests/gtest_arithmetic_functions.cpp

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
#include <Interpreters/Context.h>
2020
#include <TestUtils/FunctionTestUtils.h>
2121
#include <TestUtils/TiFlashTestBasic.h>
22+
#include <gtest/gtest.h>
2223

24+
#include <Functions/divide.cpp>
2325
#include <string>
2426
#include <unordered_map>
2527
#include <vector>
@@ -103,6 +105,141 @@ class TestBinaryArithmeticFunctions : public DB::tests::FunctionTest
103105
}
104106
};
105107

108+
template <typename TYPE>
109+
void doTiDBDivideDecimalRoundInternalTest()
110+
{
111+
auto apply = static_cast<TYPE (*)(TYPE, TYPE)>(&TiDBDivideFloatingImpl<TYPE, TYPE, false>::apply);
112+
113+
constexpr TYPE max = std::numeric_limits<TYPE>::max();
114+
// note: Int256's min is not equal to -max-1
115+
// according to https://www.boost.org/doc/libs/1_60_0/libs/multiprecision/doc/html/boost_multiprecision/tut/ints/cpp_int.html
116+
constexpr TYPE min = std::numeric_limits<TYPE>::min();
117+
118+
// clang-format off
119+
const std::vector<std::array<TYPE, 3>> cases = {
120+
{1, 2, 1}, {1, -2, -1}, {-1, 2, -1}, {-1, -2, 1},
121+
122+
{0, 3, 0}, {0, -3, 0}, {0, 3, 0}, {0, -3, 0},
123+
{1, 3, 0}, {1, -3, 0}, {-1, 3, 0}, {-1, -3, 0},
124+
{2, 3, 1}, {2, -3, -1}, {-2, 3, -1}, {-2, -3, 1},
125+
{3, 3, 1}, {3, -3, -1}, {-3, 3, -1}, {-3, -3, 1},
126+
{4, 3, 1}, {4, -3, -1}, {-4, 3, -1}, {-4, -3, 1},
127+
{5, 3, 2}, {5, -3, -2}, {-5, 3, -2}, {-5, -3, 2},
128+
129+
// ±max as divisor
130+
{0, max, 0}, {max/2-1, max, 0}, {max/2, max, 0}, {max/2+1, max, 1}, {max-1, max, 1}, {max, max, 1},
131+
{-1, max, 0}, {-max/2+1, max, 0}, {-max/2, max, 0}, {-max/2-1, max, -1}, {-max+1, max, -1}, {-max, max, -1}, {min, max, -1},
132+
{0, -max, 0}, {max/2-1, -max, 0}, {max/2, -max, 0}, {max/2+1, -max, -1}, {max-1, -max, -1}, {max, -max, -1},
133+
{-1, -max, 0}, {-max/2+1, -max, 0}, {-max/2, -max, 0}, {-max/2-1, -max, 1}, {-max+1, -max, 1}, {-max, -max, 1}, {min, -max, 1},
134+
135+
// ±max as dividend
136+
{max, 1, max}, {max, 2, max/2+1}, {max, max/2-1, 2}, {max, max/2, 2}, {max, max/2+1, 2}, {max, max-1, 1},
137+
{max, -1, -max}, {max, -2, -max/2-1}, {max, -max/2+1, -2}, {max, -max/2, -2}, {max, -max/2-1, -2}, {max, -max+1, -1},
138+
{-max, 1, -max}, {-max, 2, -max/2-1}, {-max, max/2+1, -2}, {-max, max/2, -2}, {-max, max/2-1, -2}, {-max, max-1, -1},
139+
{-max, -1, max}, {-max, -2, max/2+1}, {-max, -max/2-1, 2}, {-max, -max/2, 2}, {-max, -max/2+1, 2}, {-max, -max+1, 1},
140+
};
141+
// clang-format on
142+
143+
for (const auto & expect : cases)
144+
{
145+
std::array<TYPE, 3> actual = {expect[0], expect[1], apply(expect[0], expect[1])};
146+
ASSERT_EQ(expect, actual);
147+
}
148+
}
149+
150+
TEST_F(TestBinaryArithmeticFunctions, TiDBDivideDecimalRoundInternal)
151+
try
152+
{
153+
doTiDBDivideDecimalRoundInternalTest<Int32>();
154+
doTiDBDivideDecimalRoundInternalTest<Int64>();
155+
doTiDBDivideDecimalRoundInternalTest<Int128>();
156+
doTiDBDivideDecimalRoundInternalTest<Int256>();
157+
}
158+
CATCH
159+
160+
TEST_F(TestBinaryArithmeticFunctions, TiDBDivideDecimalRound)
161+
try
162+
{
163+
const String func_name = "tidbDivide";
164+
165+
// decimal32
166+
{
167+
// int and decimal
168+
ASSERT_COLUMN_EQ(
169+
createColumn<Nullable<Decimal64>>(std::make_tuple(18, 4), {DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(0, 4)}),
170+
executeFunction(
171+
func_name,
172+
createColumn<Int32>({1, 1, 1, 1, 1}),
173+
createColumn<Decimal32>(std::make_tuple(20, 4), {DecimalField32(100000000, 4), DecimalField32(100010000, 4), DecimalField32(199990000, 4), DecimalField32(200000000, 4), DecimalField32(200010000, 4)})));
174+
175+
// decimal and decimal
176+
ASSERT_COLUMN_EQ(
177+
createColumn<Nullable<Decimal128>>(std::make_tuple(26, 8), {DecimalField128(10000, 8), DecimalField128(9999, 8), DecimalField128(5000, 8), DecimalField128(5000, 8), DecimalField128(5000, 8)}),
178+
executeFunction(
179+
func_name,
180+
createColumn<Decimal32>(std::make_tuple(18, 4), {DecimalField32(10000, 4), DecimalField32(10000, 4), DecimalField32(10000, 4), DecimalField32(10000, 4), DecimalField32(10000, 4)}),
181+
createColumn<Decimal32>(std::make_tuple(18, 4), {DecimalField32(100000000, 4), DecimalField32(100010000, 4), DecimalField32(199990000, 4), DecimalField32(200000000, 4), DecimalField32(200010000, 4)})));
182+
}
183+
184+
// decimal64
185+
{
186+
// int and decimal
187+
ASSERT_COLUMN_EQ(
188+
createColumn<Nullable<Decimal64>>(std::make_tuple(18, 4), {DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(0, 4)}),
189+
executeFunction(
190+
func_name,
191+
createColumn<Int32>({1, 1, 1, 1, 1}),
192+
createColumn<Decimal64>(std::make_tuple(20, 4), {DecimalField64(100000000, 4), DecimalField64(100010000, 4), DecimalField64(199990000, 4), DecimalField64(200000000, 4), DecimalField64(200010000, 4)})));
193+
194+
// decimal and decimal
195+
ASSERT_COLUMN_EQ(
196+
createColumn<Nullable<Decimal128>>(std::make_tuple(26, 8), {DecimalField128(10000, 8), DecimalField128(9999, 8), DecimalField128(5000, 8), DecimalField128(5000, 8), DecimalField128(5000, 8)}),
197+
executeFunction(
198+
func_name,
199+
createColumn<Decimal64>(std::make_tuple(18, 4), {DecimalField64(10000, 4), DecimalField64(10000, 4), DecimalField64(10000, 4), DecimalField64(10000, 4), DecimalField64(10000, 4)}),
200+
createColumn<Decimal64>(std::make_tuple(18, 4), {DecimalField64(100000000, 4), DecimalField64(100010000, 4), DecimalField64(199990000, 4), DecimalField64(200000000, 4), DecimalField64(200010000, 4)})));
201+
}
202+
203+
// decimal128
204+
{
205+
// int and decimal
206+
ASSERT_COLUMN_EQ(
207+
createColumn<Nullable<Decimal64>>(std::make_tuple(18, 4), {DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(0, 4)}),
208+
executeFunction(
209+
func_name,
210+
createColumn<Int32>({1, 1, 1, 1, 1}),
211+
createColumn<Decimal128>(std::make_tuple(20, 4), {DecimalField128(100000000, 4), DecimalField128(100010000, 4), DecimalField128(199990000, 4), DecimalField128(200000000, 4), DecimalField128(200010000, 4)})));
212+
213+
// decimal and decimal
214+
ASSERT_COLUMN_EQ(
215+
createColumn<Nullable<Decimal128>>(std::make_tuple(26, 8), {DecimalField128(10000, 8), DecimalField128(9999, 8), DecimalField128(5000, 8), DecimalField128(5000, 8), DecimalField128(5000, 8)}),
216+
executeFunction(
217+
func_name,
218+
createColumn<Decimal128>(std::make_tuple(18, 4), {DecimalField128(10000, 4), DecimalField128(10000, 4), DecimalField128(10000, 4), DecimalField128(10000, 4), DecimalField128(10000, 4)}),
219+
createColumn<Decimal128>(std::make_tuple(18, 4), {DecimalField128(100000000, 4), DecimalField128(100010000, 4), DecimalField128(199990000, 4), DecimalField128(200000000, 4), DecimalField128(200010000, 4)})));
220+
}
221+
222+
// decimal256
223+
{
224+
// int and decimal
225+
ASSERT_COLUMN_EQ(
226+
createColumn<Nullable<Decimal64>>(std::make_tuple(18, 4), {DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(1, 4), DecimalField64(0, 4)}),
227+
executeFunction(
228+
func_name,
229+
createColumn<Int32>({1, 1, 1, 1, 1}),
230+
createColumn<Decimal256>(std::make_tuple(20, 4), {DecimalField256(Int256(100000000), 4), DecimalField256(Int256(100010000), 4), DecimalField256(Int256(199990000), 4), DecimalField256(Int256(200000000), 4), DecimalField256(Int256(200010000), 4)})));
231+
232+
// decimal and decimal
233+
ASSERT_COLUMN_EQ(
234+
createColumn<Nullable<Decimal128>>(std::make_tuple(26, 8), {DecimalField128(10000, 8), DecimalField128(9999, 8), DecimalField128(5000, 8), DecimalField128(5000, 8), DecimalField128(5000, 8)}),
235+
executeFunction(
236+
func_name,
237+
createColumn<Decimal256>(std::make_tuple(18, 4), {DecimalField256(Int256(10000), 4), DecimalField256(Int256(10000), 4), DecimalField256(Int256(10000), 4), DecimalField256(Int256(10000), 4), DecimalField256(Int256(10000), 4)}),
238+
createColumn<Decimal256>(std::make_tuple(18, 4), {DecimalField256(Int256(100000000), 4), DecimalField256(Int256(100010000), 4), DecimalField256(Int256(199990000), 4), DecimalField256(Int256(200000000), 4), DecimalField256(Int256(200010000), 4)})));
239+
}
240+
}
241+
CATCH
242+
106243
TEST_F(TestBinaryArithmeticFunctions, TiDBDivideDecimal)
107244
try
108245
{
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright 2023 PingCAP, Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# decimal / decimal
16+
mysql> drop table if exists test.t;
17+
mysql> create table test.t(a decimal(4,0), b decimal(40, 20));
18+
mysql> alter table test.t set tiflash replica 1
19+
mysql> insert into test.t values (1, 10000), (1, 10001), (1, 20000), (1, 20001);
20+
func> wait_table test t
21+
mysql> set tidb_enforce_mpp=1; select a, b, a/b from test.t order by b;
22+
+------+----------------------------+--------+
23+
| a | b | a/b |
24+
+------+----------------------------+--------+
25+
| 1 | 10000.00000000000000000000 | 0.0001 |
26+
| 1 | 10001.00000000000000000000 | 0.0001 |
27+
| 1 | 20000.00000000000000000000 | 0.0001 |
28+
| 1 | 20001.00000000000000000000 | 0.0000 |
29+
+------+----------------------------+--------+
30+
31+
# int / decimal
32+
mysql> drop table if exists test.t;
33+
mysql> create table test.t(a int, b decimal(40, 20));
34+
mysql> alter table test.t set tiflash replica 1
35+
mysql> insert into test.t values (1, 10000), (1, 10001), (1, 20000), (1, 20001);
36+
func> wait_table test t
37+
mysql> set tidb_enforce_mpp=1; select a, b, a/b from test.t order by b;
38+
+------+----------------------------+--------+
39+
| a | b | a/b |
40+
+------+----------------------------+--------+
41+
| 1 | 10000.00000000000000000000 | 0.0001 |
42+
| 1 | 10001.00000000000000000000 | 0.0001 |
43+
| 1 | 20000.00000000000000000000 | 0.0001 |
44+
| 1 | 20001.00000000000000000000 | 0.0000 |
45+
+------+----------------------------+--------+
46+
47+
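# decimal / int (NOTE: the table below is declared as `a int, b decimal(40, 20)`, so `a/b` is still int / decimal, same as the previous case)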
48+
mysql> drop table if exists test.t;
49+
mysql> create table test.t(a int, b decimal(40, 20));
50+
mysql> alter table test.t set tiflash replica 1
51+
mysql> insert into test.t values (1, 10000), (1, 10001), (1, 20000), (1, 20001);
52+
func> wait_table test t
53+
mysql> set tidb_enforce_mpp=1; select a, b, a/b from test.t order by b;
54+
+------+----------------------------+--------+
55+
| a | b | a/b |
56+
+------+----------------------------+--------+
57+
| 1 | 10000.00000000000000000000 | 0.0001 |
58+
| 1 | 10001.00000000000000000000 | 0.0001 |
59+
| 1 | 20000.00000000000000000000 | 0.0001 |
60+
| 1 | 20001.00000000000000000000 | 0.0000 |
61+
+------+----------------------------+--------+
62+
63+
# int / int
64+
mysql> drop table if exists test.t;
65+
mysql> create table test.t(a int, b int);
66+
mysql> alter table test.t set tiflash replica 1
67+
mysql> insert into test.t values (1, 10000), (1, 10001), (1, 20000), (1, 20001);
68+
func> wait_table test t
69+
mysql> set tidb_enforce_mpp=1; select a, b, a/b from test.t order by b;
70+
+------+-------+--------+
71+
| a | b | a/b |
72+
+------+-------+--------+
73+
| 1 | 10000 | 0.0001 |
74+
| 1 | 10001 | 0.0001 |
75+
| 1 | 20000 | 0.0001 |
76+
| 1 | 20001 | 0.0000 |
77+
+------+-------+--------+
78+
79+
mysql> drop table if exists test.t;
80+
mysql> create table test.t(a decimal(10,0), b decimal(10,0));
81+
mysql> alter table test.t set tiflash replica 1
82+
mysql> insert into test.t values (2147483647, 1), (2147483647, 1073741823), (2147483647, 1073741824), (2147483647, 2147483646), (2147483647, 2147483647);
83+
mysql> insert into test.t values (-2147483647, 1), (-2147483647, 1073741823), (-2147483647, 1073741824), (-2147483647, 2147483646), (-2147483647, 2147483647);
84+
mysql> insert into test.t values (-2147483647, -1), (-2147483647, -1073741823), (-2147483647, -1073741824), (-2147483647, -2147483646), (-2147483647, -2147483647);
85+
mysql> insert into test.t values (2147483647, -1), (2147483647, -1073741823), (2147483647, -1073741824), (2147483647, -2147483646), (2147483647, -2147483647);
86+
func> wait_table test t
87+
mysql> set tidb_enforce_mpp=1; select b, a, b/(a*10000) from test.t where a/b order by b;
88+
+-------------+-------------+-------------+
89+
| b | a | b/(a*10000) |
90+
+-------------+-------------+-------------+
91+
| -2147483647 | 2147483647 | -0.0001 |
92+
| -2147483647 | -2147483647 | 0.0001 |
93+
| -2147483646 | 2147483647 | -0.0001 |
94+
| -2147483646 | -2147483647 | 0.0001 |
95+
| -1073741824 | 2147483647 | -0.0001 |
96+
| -1073741824 | -2147483647 | 0.0001 |
97+
| -1073741823 | -2147483647 | 0.0000 |
98+
| -1073741823 | 2147483647 | 0.0000 |
99+
| -1 | 2147483647 | 0.0000 |
100+
| -1 | -2147483647 | 0.0000 |
101+
| 1 | -2147483647 | 0.0000 |
102+
| 1 | 2147483647 | 0.0000 |
103+
| 1073741823 | -2147483647 | 0.0000 |
104+
| 1073741823 | 2147483647 | 0.0000 |
105+
| 1073741824 | -2147483647 | -0.0001 |
106+
| 1073741824 | 2147483647 | 0.0001 |
107+
| 2147483646 | -2147483647 | -0.0001 |
108+
| 2147483646 | 2147483647 | 0.0001 |
109+
| 2147483647 | -2147483647 | -0.0001 |
110+
| 2147483647 | 2147483647 | 0.0001 |
111+
+-------------+-------------+-------------+
112+
mysql> delete from test.t;
113+
mysql> insert into test.t values (2147483647, 9999999999), (9999999999, 2147483647), (1, 9999999999), (4999999999, 9999999999), (5000000000, 9999999999);
114+
mysql> insert into test.t values (-2147483647, 9999999999), (-9999999999, 2147483647), (-1, 9999999999), (-4999999999, 9999999999), (-5000000000, 9999999999);
115+
mysql> insert into test.t values (-2147483647, -9999999999), (-9999999999, -2147483647), (-1, -9999999999), (-4999999999, -9999999999), (-5000000000, -9999999999);
116+
mysql> insert into test.t values (2147483647, -9999999999), (9999999999, -2147483647), (1, -9999999999), (4999999999, -9999999999), (5000000000, -9999999999);
117+
mysql> set tidb_enforce_mpp=1; select b, a, b/(a*10000) from test.t where a/b order by b;
118+
+-------------+-------------+-------------+
119+
| b | a | b/(a*10000) |
120+
+-------------+-------------+-------------+
121+
| -9999999999 | 2147483647 | -0.0005 |
122+
| -9999999999 | -4999999999 | 0.0002 |
123+
| -9999999999 | 5000000000 | -0.0002 |
124+
| -9999999999 | 4999999999 | -0.0002 |
125+
| -9999999999 | -2147483647 | 0.0005 |
126+
| -9999999999 | -5000000000 | 0.0002 |
127+
| -2147483647 | -9999999999 | 0.0000 |
128+
| -2147483647 | 9999999999 | 0.0000 |
129+
| 2147483647 | 9999999999 | 0.0000 |
130+
| 2147483647 | -9999999999 | 0.0000 |
131+
| 9999999999 | -4999999999 | -0.0002 |
132+
| 9999999999 | -2147483647 | -0.0005 |
133+
| 9999999999 | -5000000000 | -0.0002 |
134+
| 9999999999 | 2147483647 | 0.0005 |
135+
| 9999999999 | 5000000000 | 0.0002 |
136+
| 9999999999 | 4999999999 | 0.0002 |
137+
+-------------+-------------+-------------+

tests/tidb-ci/fullstack-test-dt/issue_1425.test

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,20 @@ mysql> drop table if exists test.t;
1616

1717
mysql> create table test.t (id int, value decimal(7,4), c1 int, c2 int);
1818

19-
mysql> insert into test.t values(1,1.9286,54,28);
19+
mysql> insert into test.t values (1,1.9285,54,28), (1,1.9286,54,28);
2020

2121
mysql> alter table test.t set tiflash replica 1;
2222

2323
func> wait_table test t
2424

25+
# note: ref to https://github.com/pingcap/tiflash/issues/1682,
26+
# The precision of tiflash results is different from that of tidb, which is a compatibility issue
2527
mysql> use test; set session tidb_isolation_read_engines='tiflash'; select * from t where value = 54/28;
26-
2728
mysql> use test; set session tidb_isolation_read_engines='tiflash'; select * from t where value = c1/c2;
29+
+------+--------+------+------+
30+
| id | value | c1 | c2 |
31+
+------+--------+------+------+
32+
| 1 | 1.9286 | 54 | 28 |
33+
+------+--------+------+------+
2834

2935
mysql> drop table if exists test.t;

0 commit comments

Comments
 (0)