@@ -2548,4 +2548,77 @@ public double count(double x) {
2548
2548
}
2549
2549
2550
2550
}
2551
+
2552
+ @ Builtin (name = "isqrt" , minNumOfPositionalArgs = 1 )
2553
+ @ TypeSystemReference (PythonArithmeticTypes .class )
2554
+ @ GenerateNodeFactory
2555
+ @ ImportStatic (MathGuards .class )
2556
+ public abstract static class IsqrtNode extends PythonUnaryBuiltinNode {
2557
+
2558
+ @ Specialization
2559
+ Object isqrtLong (long x ) {
2560
+ raiseIfNegative (x < 0 );
2561
+ return makeInt (op (PInt .longToBigInteger (x )));
2562
+ }
2563
+
2564
+ @ Specialization
2565
+ Object isqrtPInt (PInt x ) {
2566
+ raiseIfNegative (x .isNegative ());
2567
+ return makeInt (op (x .getValue ()));
2568
+ }
2569
+
2570
+ @ Specialization (guards = "!isInteger(x)" )
2571
+ Object doGeneral (VirtualFrame frame , Object x ,
2572
+ @ Cached ("createBinaryProfile()" ) ConditionProfile hasFrame ,
2573
+ @ CachedLibrary (limit = "1" ) PythonObjectLibrary lib ,
2574
+ @ Cached IsqrtNode recursiveNode ) {
2575
+ return recursiveNode .execute (frame , lib .asIndexWithFrame (x , hasFrame , frame ));
2576
+ }
2577
+
2578
+ private Object makeInt (BigInteger i ) {
2579
+ try {
2580
+ return PInt .intValueExact (i );
2581
+ } catch (ArithmeticException e ) {
2582
+ // does not fit int, so try long
2583
+ }
2584
+ try {
2585
+ return PInt .longValueExact (i );
2586
+ } catch (ArithmeticException e ) {
2587
+ // does not fit long either, create PInt
2588
+ }
2589
+ return factory ().createInt (i );
2590
+ }
2591
+
2592
+ @ TruffleBoundary
2593
+ private BigInteger op (BigInteger x ) {
2594
+ // assumes x >= 0
2595
+ if (x .equals (BigInteger .ZERO ) || x .equals (BigInteger .ONE )) {
2596
+ return x ;
2597
+ }
2598
+ BigInteger start = BigInteger .ONE ;
2599
+ BigInteger end = x ;
2600
+ BigInteger result = BigInteger .ZERO ;
2601
+ BigInteger two = BigInteger .valueOf (2 );
2602
+ while (start .compareTo (end ) <= 0 ) {
2603
+ BigInteger mid = (start .add (end ).divide (two ));
2604
+ int cmp = mid .multiply (mid ).compareTo (x );
2605
+ if (cmp == 0 ) {
2606
+ return mid ;
2607
+ }
2608
+ if (cmp < 0 ) {
2609
+ start = mid .add (BigInteger .ONE );
2610
+ result = mid ;
2611
+ } else {
2612
+ end = mid .subtract (BigInteger .ONE );
2613
+ }
2614
+ }
2615
+ return result ;
2616
+ }
2617
+
2618
+ private void raiseIfNegative (boolean condition ) {
2619
+ if (condition ) {
2620
+ throw raise (ValueError , ErrorMessages .MUST_BE_NON_NEGATIVE , "isqrt() argument" );
2621
+ }
2622
+ }
2623
+ }
2551
2624
}
0 commit comments