Skip to content

Commit 3213913

Browse files
committed
Implementation of Math.fsum.
1 parent 8db5cad commit 3213913

File tree

3 files changed

+230
-13
lines changed

3 files changed

+230
-13
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_math.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,6 +1220,7 @@ class II(int):
12201220
self.assertEqual(math.ldexp(FF(10), II(12)), 40960.0)
12211221
self.assertRaises(TypeError, math.ldexp, 'Hello', 1000000)
12221222
self.assertRaises(TypeError, math.ldexp, 1, 'Hello')
1223+
self.assertEqual(math.ldexp(7589167167882033, -48), 26.962138008038156)
12231224

12241225
def test_trunc(self):
12251226
self.assertEqual(math.trunc(1), 1)
@@ -1293,3 +1294,95 @@ def testmodf(name, result, expected):
12931294
testmodf('modf(MyFloat())', math.modf(MyFloat()), (0.6, 0.0))
12941295
self.assertRaises(TypeError, math.modf, 'ahoj')
12951296
testmodf('modf(BIG_INT)', math.modf(BIG_INT), (0.0, 9.999992432902008e+30))
1297+
1298+
def testFsum(self):
1299+
# math.fsum relies on exact rounding for correct operation.
1300+
# There's a known problem with IA32 floating-point that causes
1301+
# inexact rounding in some situations, and will cause the
1302+
# math.fsum tests below to fail; see issue #2937. On non IEEE
1303+
# 754 platforms, and on IEEE 754 platforms that exhibit the
1304+
# problem described in issue #2937, we simply skip the whole
1305+
# test.
1306+
1307+
# Python version of math.fsum, for comparison. Uses a
1308+
# different algorithm based on frexp, ldexp and integer
1309+
# arithmetic.
1310+
1311+
from sys import float_info
1312+
mant_dig = float_info.mant_dig
1313+
etiny = float_info.min_exp - mant_dig
1314+
1315+
def msum(iterable):
1316+
"""Full precision summation. Compute sum(iterable) without any
1317+
intermediate accumulation of error. Based on the 'lsum' function
1318+
at http://code.activestate.com/recipes/393090/
1319+
1320+
"""
1321+
tmant, texp = 0, 0
1322+
for x in iterable:
1323+
mant, exp = math.frexp(x)
1324+
mant, exp = int(math.ldexp(mant, mant_dig)), exp - mant_dig
1325+
if texp > exp:
1326+
tmant <<= texp-exp
1327+
texp = exp
1328+
else:
1329+
mant <<= exp-texp
1330+
tmant += mant
1331+
# Round tmant * 2**texp to a float. The original recipe
1332+
# used float(str(tmant)) * 2.0**texp for this, but that's
1333+
# a little unsafe because str -> float conversion can't be
1334+
# relied upon to do correct rounding on all platforms.
1335+
tail = max(len(bin(abs(tmant)))-2 - mant_dig, etiny - texp)
1336+
if tail > 0:
1337+
h = 1 << (tail-1)
1338+
tmant = tmant // (2*h) + bool(tmant & h and tmant & 3*h-1)
1339+
texp += tail
1340+
return math.ldexp(tmant, texp)
1341+
1342+
test_values = [
1343+
([], 0.0),
1344+
([0.0], 0.0),
1345+
([1e100, 1.0, -1e100, 1e-100, 1e50, -1.0, -1e50], 1e-100),
1346+
([2.0**53, -0.5, -2.0**-54], 2.0**53-1.0),
1347+
([2.0**53, 1.0, 2.0**-100], 2.0**53+2.0),
1348+
([2.0**53+10.0, 1.0, 2.0**-100], 2.0**53+12.0),
1349+
([2.0**53-4.0, 0.5, 2.0**-54], 2.0**53-3.0),
1350+
([1./n for n in range(1, 1001)],
1351+
float.fromhex('0x1.df11f45f4e61ap+2')),
1352+
([(-1.)**n/n for n in range(1, 1001)],
1353+
float.fromhex('-0x1.62a2af1bd3624p-1')),
1354+
([1.7**(i+1)-1.7**i for i in range(1000)] + [-1.7**1000], -1.0),
1355+
([1e16, 1., 1e-16], 10000000000000002.0),
1356+
([1e16-2., 1.-2.**-53, -(1e16-2.), -(1.-2.**-53)], 0.0),
1357+
# exercise code for resizing partials array
1358+
([2.**n - 2.**(n+50) + 2.**(n+52) for n in range(-1074, 972, 2)] +
1359+
[-2.**1022],
1360+
float.fromhex('0x1.5555555555555p+970')),
1361+
]
1362+
1363+
for i, (vals, expected) in enumerate(test_values):
1364+
try:
1365+
actual = math.fsum(vals)
1366+
except OverflowError:
1367+
self.fail("test %d failed: got OverflowError, expected %r "
1368+
"for math.fsum(%.100r)" % (i, expected, vals))
1369+
except ValueError:
1370+
self.fail("test %d failed: got ValueError, expected %r "
1371+
"for math.fsum(%.100r)" % (i, expected, vals))
1372+
self.assertEqual(actual, expected)
1373+
1374+
from random import random, gauss, shuffle
1375+
for j in range(1000):
1376+
vals = [7, 1e100, -7, -1e100, -9e-20, 8e-20] * 10
1377+
s = 0
1378+
for i in range(200):
1379+
v = gauss(0, random()) ** 7 - s
1380+
s += v
1381+
vals.append(v)
1382+
shuffle(vals)
1383+
1384+
s = msum(vals)
1385+
self.assertEqual(msum(vals), math.fsum(vals))
1386+
1387+
self.assertRaises(ValueError, math.fsum, [1., 2, INF, NINF])
1388+
self.assertEqual(math.fsum([1., 2, INF, INF]), INF)

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BuiltinFunctions.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -198,50 +198,50 @@ public Object absObject(Object object,
198198
return result;
199199
}
200200
}
201-
201+
202202
// bin(object)
203203
@Builtin(name = BIN, fixedNumOfArguments = 1)
204204
@TypeSystemReference(PythonArithmeticTypes.class)
205205
@GenerateNodeFactory
206206
public abstract static class BinNode extends PythonUnaryBuiltinNode {
207-
207+
208208
public abstract String executeObject(Object x);
209-
209+
210210
private String buildString(boolean isNegative, String number) {
211211
StringBuilder sb = new StringBuilder();
212-
if(isNegative) {
212+
if (isNegative) {
213213
sb.append('-');
214214
}
215215
sb.append("0b");
216216
sb.append(number);
217217
return sb.toString();
218218
}
219-
219+
220220
@Specialization
221221
public String doL(long x) {
222222
return buildString(x < 0, Long.toBinaryString(Math.abs(x)));
223223
}
224-
224+
225225
@Specialization
226226
public String doD(double x) {
227227
throw raise(TypeError, "'%p' object cannot be interpreted as an integer", x);
228228
}
229-
229+
230230
@Specialization
231231
@TruffleBoundary
232232
public String doPI(PInt x) {
233233
BigInteger value = x.getValue();
234234
return buildString(value.compareTo(BigInteger.ZERO) == -1, value.abs().toString(2));
235235
}
236-
236+
237237
@Specialization
238238
public String doO(Object x,
239-
@Cached("create()") MathModuleBuiltins.ConvertToIntNode toIntNode,
240-
@Cached("create()") BinNode recursiveNode) {
239+
@Cached("create()") MathModuleBuiltins.ConvertToIntNode toIntNode,
240+
@Cached("create()") BinNode recursiveNode) {
241241
Object value = toIntNode.execute(x);
242242
return recursiveNode.executeObject(value);
243243
}
244-
244+
245245
protected BinNode create() {
246246
return BuiltinFunctionsFactory.BinNodeFactory.create();
247247
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/MathModuleBuiltins.java

Lines changed: 126 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,13 @@
4646
import com.oracle.graal.python.nodes.PGuards;
4747
import com.oracle.graal.python.nodes.SpecialMethodNames;
4848
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode;
49+
import com.oracle.graal.python.nodes.control.GetIteratorNode;
4950
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
5051
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
5152
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
5253
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
5354
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
55+
import com.oracle.graal.python.runtime.exception.PException;
5456
import com.oracle.graal.python.runtime.exception.PythonErrorType;
5557
import static com.oracle.graal.python.runtime.exception.PythonErrorType.NotImplementedError;
5658
import com.oracle.graal.python.runtime.object.PythonObjectFactory;
@@ -65,6 +67,7 @@
6567
import com.oracle.truffle.api.dsl.Specialization;
6668
import com.oracle.truffle.api.dsl.TypeSystemReference;
6769
import com.oracle.truffle.api.profiles.ConditionProfile;
70+
import java.util.Arrays;
6871

6972
@CoreFunctions(defineModule = "math")
7073
public class MathModuleBuiltins extends PythonBuiltins {
@@ -968,7 +971,7 @@ public double ldexpLD(long mantissa, double exp) {
968971

969972
@Specialization
970973
public double ldexpLL(long mantissa, long exp) {
971-
return exceptInfinity(Math.scalb(mantissa, makeInt(exp)), mantissa);
974+
return exceptInfinity(Math.scalb((double) mantissa, makeInt(exp)), mantissa);
972975
}
973976

974977
@Specialization
@@ -980,7 +983,7 @@ public double ldexpDPI(double mantissa, PInt exp) {
980983
@Specialization
981984
@TruffleBoundary
982985
public double ldexpLPI(long mantissa, PInt exp) {
983-
return exceptInfinity(Math.scalb(mantissa, makeInt(exp)), mantissa);
986+
return exceptInfinity(Math.scalb((double) mantissa, makeInt(exp)), mantissa);
984987
}
985988

986989
@Specialization
@@ -1048,6 +1051,127 @@ public PTuple frexpO(Object value,
10481051
}
10491052
}
10501053

1054+
@Builtin(name = "fsum", fixedNumOfArguments = 1)
1055+
@ImportStatic(PGuards.class)
1056+
@GenerateNodeFactory
1057+
public abstract static class FsumNode extends PythonUnaryBuiltinNode {
1058+
1059+
/*
1060+
* This implementation is taken from CPython. The performance is not good. Should be faster.
1061+
* It can be easily replace with much simpler code based on BigDecimal:
1062+
*
1063+
* BigDecimal result = BigDecimal.ZERO;
1064+
*
1065+
* in cycle just: result = result.add(BigDecimal.valueof(x); ... The current implementation
1066+
* is little bit faster. The testFSum in test_math.py takes in different implementations:
1067+
* CPython ~0.6s CurrentImpl: ~14.3s Using BigDecimal: ~15.1
1068+
*/
1069+
@Specialization
1070+
@TruffleBoundary
1071+
public double doIt(Object iterable,
1072+
@Cached("create()") GetIteratorNode getIterator,
1073+
@Cached("create(__NEXT__)") LookupAndCallUnaryNode next,
1074+
@Cached("create()") ConvertToFloatNode toFloat,
1075+
@Cached("createBinaryProfile()") ConditionProfile stopProfile) {
1076+
Object iterator = getIterator.executeWith(iterable);
1077+
double x, y, t, hi, lo = 0, yr, inf_sum = 0, special_sum = 0, sum;
1078+
double xsave;
1079+
int i, j, n = 0, arayLength = 32;
1080+
double[] p = new double[arayLength];
1081+
while (true) {
1082+
try {
1083+
x = toFloat.execute(next.executeObject(iterator));
1084+
} catch (PException e) {
1085+
e.expectStopIteration(getCore(), stopProfile);
1086+
break;
1087+
}
1088+
xsave = x;
1089+
for (i = j = 0; j < n; j++) { /* for y in partials */
1090+
y = p[j];
1091+
if (Math.abs(x) < Math.abs(y)) {
1092+
t = x;
1093+
x = y;
1094+
y = t;
1095+
}
1096+
hi = x + y;
1097+
yr = hi - x;
1098+
lo = y - yr;
1099+
if (lo != 0.0) {
1100+
p[i++] = lo;
1101+
}
1102+
x = hi;
1103+
}
1104+
1105+
n = i;
1106+
if (x != 0.0) {
1107+
if (!Double.isFinite(x)) {
1108+
/*
1109+
* a nonfinite x could arise either as a result of intermediate overflow, or
1110+
* as a result of a nan or inf in the summands
1111+
*/
1112+
if (Double.isFinite(xsave)) {
1113+
throw raise(OverflowError, "intermediate overflow in fsum");
1114+
}
1115+
if (Double.isInfinite(xsave)) {
1116+
inf_sum += xsave;
1117+
}
1118+
special_sum += xsave;
1119+
/* reset partials */
1120+
n = 0;
1121+
} else if (n >= arayLength) {
1122+
arayLength += arayLength;
1123+
p = Arrays.copyOf(p, arayLength);
1124+
} else {
1125+
p[n++] = x;
1126+
}
1127+
}
1128+
}
1129+
1130+
if (special_sum != 0.0) {
1131+
if (Double.isNaN(inf_sum)) {
1132+
throw raise(ValueError, "-inf + inf in fsum");
1133+
} else {
1134+
sum = special_sum;
1135+
return sum;
1136+
}
1137+
}
1138+
1139+
hi = 0.0;
1140+
if (n > 0) {
1141+
hi = p[--n];
1142+
/*
1143+
* sum_exact(ps, hi) from the top, stop when the sum becomes inexact.
1144+
*/
1145+
while (n > 0) {
1146+
x = hi;
1147+
y = p[--n];
1148+
assert (Math.abs(y) < Math.abs(x));
1149+
hi = x + y;
1150+
yr = hi - x;
1151+
lo = y - yr;
1152+
if (lo != 0.0)
1153+
break;
1154+
}
1155+
/*
1156+
* Make half-even rounding work across multiple partials. Needed so that sum([1e-16,
1157+
* 1, 1e16]) will round-up the last digit to two instead of down to zero (the 1e-16
1158+
* makes the 1 slightly closer to two). With a potential 1 ULP rounding error
1159+
* fixed-up, math.fsum() can guarantee commutativity.
1160+
*/
1161+
if (n > 0 && ((lo < 0.0 && p[n - 1] < 0.0) ||
1162+
(lo > 0.0 && p[n - 1] > 0.0))) {
1163+
y = lo * 2.0;
1164+
x = hi + y;
1165+
yr = x - hi;
1166+
if (y == yr) {
1167+
hi = x;
1168+
}
1169+
}
1170+
}
1171+
return hi;
1172+
}
1173+
}
1174+
10511175
@Builtin(name = "gcd", fixedNumOfArguments = 2)
10521176
@TypeSystemReference(PythonArithmeticTypes.class)
10531177
@GenerateNodeFactory

0 commit comments

Comments
 (0)