@@ -144,22 +144,141 @@ protected void raiseDivisionByZero(boolean cond) {
144
144
@ TypeSystemReference (PythonArithmeticTypes .class )
145
145
abstract static class RoundNode extends PythonBinaryBuiltinNode {
146
146
@ SuppressWarnings ("unused" )
147
- @ Specialization ( guards = "isPNone(n) || isInteger(n)" )
148
- public int roundInt (int arg , Object n ) {
147
+ @ Specialization
148
+ public int roundIntNone (int arg , PNone n ) {
149
149
return arg ;
150
150
}
151
151
152
152
@ SuppressWarnings ("unused" )
153
- @ Specialization ( guards = "isPNone(n) || isInteger(n)" )
154
- public long roundLong (long arg , Object n ) {
153
+ @ Specialization
154
+ public long roundLongNone (long arg , PNone n ) {
155
155
return arg ;
156
156
}
157
157
158
158
@ SuppressWarnings ("unused" )
159
- @ Specialization ( guards = "isPNone(n) || isInteger(n)" )
160
- public PInt roundPInt (PInt arg , Object n ) {
159
+ @ Specialization
160
+ public PInt roundPIntNone (PInt arg , PNone n ) {
161
161
return factory ().createInt (arg .getValue ());
162
162
}
163
+
164
+ @ Specialization
165
+ public Object roundLongInt (long arg , int n ) {
166
+ if (n >= 0 ) {
167
+ return arg ;
168
+ }
169
+ return makeInt (op (arg , n ));
170
+ }
171
+
172
+ @ Specialization
173
+ public Object roundPIntInt (PInt arg , int n ) {
174
+ if (n >= 0 ) {
175
+ return arg ;
176
+ }
177
+ return makeInt (op (arg .getValue (), n ));
178
+ }
179
+
180
+ @ Specialization
181
+ public Object roundLongLong (long arg , long n ) {
182
+ if (n >= 0 ) {
183
+ return arg ;
184
+ }
185
+ if (n < Integer .MIN_VALUE ) {
186
+ return 0 ;
187
+ }
188
+ return makeInt (op (arg , (int ) n ));
189
+ }
190
+
191
+ @ Specialization
192
+ public Object roundPIntLong (PInt arg , long n ) {
193
+ if (n >= 0 ) {
194
+ return arg ;
195
+ }
196
+ if (n < Integer .MIN_VALUE ) {
197
+ return 0 ;
198
+ }
199
+ return makeInt (op (arg .getValue (), (int ) n ));
200
+ }
201
+
202
+ @ Specialization
203
+ public Object roundPIntLong (long arg , PInt n ) {
204
+ if (n .isZeroOrPositive ()) {
205
+ return arg ;
206
+ }
207
+ try {
208
+ return makeInt (op (arg , n .intValueExact ()));
209
+ } catch (ArithmeticException e ) {
210
+ // n is < -2^31, max. number of base-10 digits in BigInteger is 2^31 * log10(2)
211
+ return 0 ;
212
+ }
213
+ }
214
+
215
+ @ Specialization
216
+ public Object roundPIntPInt (PInt arg , PInt n ) {
217
+ if (n .isZeroOrPositive ()) {
218
+ return arg ;
219
+ }
220
+ try {
221
+ return makeInt (op (arg .getValue (), n .intValueExact ()));
222
+ } catch (ArithmeticException e ) {
223
+ // n is < -2^31, max. number of base-10 digits in BigInteger is 2^31 * log10(2)
224
+ return 0 ;
225
+ }
226
+ }
227
+
228
+ @ Specialization (guards = {"!isInteger(n)" })
229
+ @ SuppressWarnings ("unused" )
230
+ public Object roundPIntPInt (Object arg , Object n ) {
231
+ throw raise (PythonErrorType .TypeError , ErrorMessages .OBJ_CANNOT_BE_INTERPRETED_AS_INTEGER , n );
232
+ }
233
+
234
+ private Object makeInt (BigDecimal d ) {
235
+ try {
236
+ return intValueExact (d );
237
+ } catch (ArithmeticException e ) {
238
+ // does not fit int, so try long
239
+ }
240
+ try {
241
+ return longValueExact (d );
242
+ } catch (ArithmeticException e ) {
243
+ // does not fit long, try BigInteger
244
+ }
245
+ try {
246
+ return factory ().createInt (d .toBigIntegerExact ());
247
+ } catch (ArithmeticException e ) {
248
+ // has non-zero fractional part, which should not happen
249
+ throw CompilerDirectives .shouldNotReachHere ("non-integer produced after rounding an integer" , e );
250
+ }
251
+ }
252
+
253
+ @ TruffleBoundary
254
+ private static int intValueExact (BigDecimal d ) {
255
+ return d .intValueExact ();
256
+ }
257
+
258
+ @ TruffleBoundary
259
+ private static long longValueExact (BigDecimal d ) {
260
+ return d .longValueExact ();
261
+ }
262
+
263
+ @ TruffleBoundary
264
+ private static BigDecimal op (long arg , int n ) {
265
+ try {
266
+ return new BigDecimal (arg ).setScale (n , RoundingMode .HALF_EVEN );
267
+ } catch (ArithmeticException e ) {
268
+ // -n exceeds max. number of base-10 digits in BigInteger
269
+ return BigDecimal .ZERO ;
270
+ }
271
+ }
272
+
273
+ @ TruffleBoundary
274
+ private static BigDecimal op (BigInteger arg , int n ) {
275
+ try {
276
+ return new BigDecimal (arg ).setScale (n , RoundingMode .HALF_EVEN );
277
+ } catch (ArithmeticException e ) {
278
+ // -n exceeds max. number of base-10 digits in BigInteger
279
+ return BigDecimal .ZERO ;
280
+ }
281
+ }
163
282
}
164
283
165
284
@ Builtin (name = SpecialMethodNames .__RADD__ , minNumOfPositionalArgs = 2 )
@@ -704,7 +823,9 @@ static long doLLFast(long left, long right, @SuppressWarnings("unused") PNone no
704
823
result = Math .multiplyExact (result , base );
705
824
}
706
825
exponent >>= 1 ;
707
- base = Math .multiplyExact (base , base );
826
+ if (exponent != 0 ) { // prevent overflow in last iteration
827
+ base = Math .multiplyExact (base , base );
828
+ }
708
829
}
709
830
return result ;
710
831
}
@@ -1313,15 +1434,13 @@ long doLL(long left, long right) {
1313
1434
}
1314
1435
1315
1436
@ Specialization
1316
- PInt doIPi (int left , PInt right ) {
1317
- raiseNegativeShiftCount (!right .isZeroOrPositive ());
1318
- return factory ().createInt (op (PInt .longToBigInteger (left ), right .intValue ()));
1437
+ Object doIPi (int left , PInt right ) {
1438
+ return doHugeShift (PInt .longToBigInteger (left ), right );
1319
1439
}
1320
1440
1321
1441
@ Specialization
1322
- PInt doLPi (long left , PInt right ) {
1323
- raiseNegativeShiftCount (!right .isZeroOrPositive ());
1324
- return factory ().createInt (op (PInt .longToBigInteger (left ), right .intValue ()));
1442
+ Object doLPi (long left , PInt right ) {
1443
+ return doHugeShift (PInt .longToBigInteger (left ), right );
1325
1444
}
1326
1445
1327
1446
@ Specialization
@@ -1331,15 +1450,20 @@ PInt doPiI(PInt left, int right) {
1331
1450
}
1332
1451
1333
1452
@ Specialization
1334
- PInt doPiL (PInt left , long right ) {
1453
+ Object doPiL (PInt left , long right ) {
1335
1454
raiseNegativeShiftCount (right < 0 );
1336
- return factory ().createInt (op (left .getValue (), (int ) right ));
1455
+ int rightI = (int ) right ;
1456
+ if (rightI == right ) {
1457
+ return factory ().createInt (op (left .getValue (), rightI ));
1458
+ }
1459
+ // right is >= 2**31, BigInteger's bitLength is at most 2**31-1
1460
+ // therefore the result of shifting right is just the sign bit
1461
+ return left .isNegative () ? -1 : 0 ;
1337
1462
}
1338
1463
1339
1464
@ Specialization
1340
- PInt doPInt (PInt left , PInt right ) {
1341
- raiseNegativeShiftCount (!right .isZeroOrPositive ());
1342
- return factory ().createInt (op (left .getValue (), right .intValue ()));
1465
+ Object doPInt (PInt left , PInt right ) {
1466
+ return doHugeShift (left .getValue (), right );
1343
1467
}
1344
1468
1345
1469
private void raiseNegativeShiftCount (boolean cond ) {
@@ -1354,8 +1478,19 @@ PNotImplemented doGeneric(Object a, Object b) {
1354
1478
return PNotImplemented .NOT_IMPLEMENTED ;
1355
1479
}
1356
1480
1481
+ private Object doHugeShift (BigInteger left , PInt right ) {
1482
+ raiseNegativeShiftCount (!right .isZeroOrPositive ());
1483
+ try {
1484
+ return factory ().createInt (op (left , right .intValueExact ()));
1485
+ } catch (ArithmeticException e ) {
1486
+ // right is >= 2**31, BigInteger's bitLength is at most 2**31-1
1487
+ // therefore the result of shifting right is just the sign bit
1488
+ return left .signum () < 0 ? -1 : 0 ;
1489
+ }
1490
+ }
1491
+
1357
1492
@ TruffleBoundary
1358
- public static BigInteger op (BigInteger left , int right ) {
1493
+ private static BigInteger op (BigInteger left , int right ) {
1359
1494
return left .shiftRight (right );
1360
1495
}
1361
1496
0 commit comments