@@ -516,6 +516,171 @@ protected static FactorialNode create() {
516
516
}
517
517
}
518
518
519
+ @ Builtin (name = "comb" , minNumOfPositionalArgs = 2 )
520
+ @ TypeSystemReference (PythonArithmeticTypes .class )
521
+ @ GenerateNodeFactory
522
+ @ ImportStatic (MathGuards .class )
523
+ public abstract static class CombNode extends PythonBinaryBuiltinNode {
524
+
525
+ @ TruffleBoundary
526
+ private BigInteger calculateComb (BigInteger n , BigInteger k ) {
527
+ if (n .signum () < 0 ) {
528
+ throw raise (ValueError , ErrorMessages .MUST_BE_NON_NEGATIVE_INTEGER , "n" );
529
+ }
530
+ if (k .signum () < 0 ) {
531
+ throw raise (ValueError , ErrorMessages .MUST_BE_NON_NEGATIVE_INTEGER , "k" );
532
+ }
533
+
534
+ BigInteger factors = k .min (n .subtract (k ));
535
+ if (factors .signum () < 0 ) {
536
+ return BigInteger .ZERO ;
537
+ }
538
+ if (factors .signum () == 0 ) {
539
+ return BigInteger .ONE ;
540
+ }
541
+ BigInteger result = n ;
542
+ BigInteger factor = n ;
543
+ BigInteger i = BigInteger .ONE ;
544
+ while (i .compareTo (factors ) < 0 ) {
545
+ factor = factor .subtract (BigInteger .ONE );
546
+ result = result .multiply (factor );
547
+ i = i .add (BigInteger .ONE );
548
+ result = result .divide (i );
549
+ }
550
+ return result ;
551
+ }
552
+
553
+ @ Specialization
554
+ PInt comb (long n , long k ) {
555
+ return factory ().createInt (calculateComb (PInt .longToBigInteger (n ), PInt .longToBigInteger (k )));
556
+ }
557
+
558
+ @ Specialization
559
+ PInt comb (long n , PInt k ) {
560
+ return factory ().createInt (calculateComb (PInt .longToBigInteger (n ), k .getValue ()));
561
+ }
562
+
563
+ @ Specialization
564
+ PInt comb (PInt n , long k ) {
565
+ return factory ().createInt (calculateComb (n .getValue (), PInt .longToBigInteger (k )));
566
+ }
567
+
568
+ @ Specialization
569
+ PInt comb (PInt n , PInt k ) {
570
+ return factory ().createInt (calculateComb (n .getValue (), k .getValue ()));
571
+ }
572
+
573
+ @ Specialization
574
+ int comb (@ SuppressWarnings ("unused" ) double n , @ SuppressWarnings ("unused" ) Object k ) {
575
+ throw raise (TypeError , ErrorMessages .OBJ_CANNOT_BE_INTERPRETED_AS_INTEGER , "float" );
576
+ }
577
+
578
+ @ Specialization
579
+ int comb (@ SuppressWarnings ("unused" ) Object n , @ SuppressWarnings ("unused" ) double k ) {
580
+ throw raise (TypeError , ErrorMessages .OBJ_CANNOT_BE_INTERPRETED_AS_INTEGER , "float" );
581
+ }
582
+
583
+ @ Specialization (guards = "!isNumber(n) || !isNumber(k)" )
584
+ Object comb (VirtualFrame frame , Object n , Object k ,
585
+ @ Cached ("createBinaryProfile()" ) ConditionProfile hasFrame ,
586
+ @ CachedLibrary (limit = "2" ) PythonObjectLibrary lib ,
587
+ @ Cached CombNode recursiveNode ) {
588
+ Object nValue = lib .asIndexWithFrame (n , hasFrame , frame );
589
+ Object kValue = lib .asIndexWithFrame (k , hasFrame , frame );
590
+ return recursiveNode .execute (frame , nValue , kValue );
591
+ }
592
+
593
+ public static CombNode create () {
594
+ return MathModuleBuiltinsFactory .CombNodeFactory .create ();
595
+ }
596
+ }
597
+
598
+ @ Builtin (name = "perm" , minNumOfPositionalArgs = 1 , parameterNames = {"n" , "k" })
599
+ @ TypeSystemReference (PythonArithmeticTypes .class )
600
+ @ GenerateNodeFactory
601
+ @ ImportStatic (MathGuards .class )
602
+ public abstract static class PermNode extends PythonBinaryBuiltinNode {
603
+
604
+ @ TruffleBoundary
605
+ private BigInteger calculatePerm (BigInteger n , BigInteger k ) {
606
+ if (n .signum () < 0 ) {
607
+ throw raise (ValueError , ErrorMessages .MUST_BE_NON_NEGATIVE_INTEGER , "n" );
608
+ }
609
+ if (k .signum () < 0 ) {
610
+ throw raise (ValueError , ErrorMessages .MUST_BE_NON_NEGATIVE_INTEGER , "k" );
611
+ }
612
+ if (n .compareTo (k ) < 0 ) {
613
+ return BigInteger .ZERO ;
614
+ }
615
+ if (k .equals (BigInteger .ZERO )) {
616
+ return BigInteger .ONE ;
617
+ }
618
+ if (k .equals (BigInteger .ONE )) {
619
+ return n ;
620
+ }
621
+
622
+ BigInteger result = n ;
623
+ BigInteger factor = n ;
624
+ BigInteger i = BigInteger .ONE ;
625
+ while (i .compareTo (k ) < 0 ) {
626
+ factor = factor .subtract (BigInteger .ONE );
627
+ result = result .multiply (factor );
628
+ i = i .add (BigInteger .ONE );
629
+ }
630
+ return result ;
631
+ }
632
+
633
+ @ Specialization
634
+ PInt perm (long n , long k ) {
635
+ return factory ().createInt (calculatePerm (PInt .longToBigInteger (n ), PInt .longToBigInteger (k )));
636
+ }
637
+
638
+ @ Specialization
639
+ PInt perm (long n , PInt k ) {
640
+ return factory ().createInt (calculatePerm (PInt .longToBigInteger (n ), k .getValue ()));
641
+ }
642
+
643
+ @ Specialization
644
+ PInt perm (PInt n , long k ) {
645
+ return factory ().createInt (calculatePerm (n .getValue (), PInt .longToBigInteger (k )));
646
+ }
647
+
648
+ @ Specialization
649
+ PInt perm (PInt n , PInt k ) {
650
+ return factory ().createInt (calculatePerm (n .getValue (), k .getValue ()));
651
+ }
652
+
653
+ @ Specialization
654
+ int perm (@ SuppressWarnings ("unused" ) double n , @ SuppressWarnings ("unused" ) Object k ) {
655
+ throw raise (TypeError , ErrorMessages .OBJ_CANNOT_BE_INTERPRETED_AS_INTEGER , "float" );
656
+ }
657
+
658
+ @ Specialization
659
+ int perm (@ SuppressWarnings ("unused" ) Object n , @ SuppressWarnings ("unused" ) double k ) {
660
+ throw raise (TypeError , ErrorMessages .OBJ_CANNOT_BE_INTERPRETED_AS_INTEGER , "float" );
661
+ }
662
+
663
+ @ Specialization
664
+ Object perm (VirtualFrame frame , Object n , @ SuppressWarnings ("unused" ) PNone k ,
665
+ @ Cached FactorialNode factorialNode ) {
666
+ return factorialNode .execute (frame , n );
667
+ }
668
+
669
+ @ Specialization (guards = "!isNumber(n) || !isNumber(k)" )
670
+ Object perm (VirtualFrame frame , Object n , Object k ,
671
+ @ Cached ("createBinaryProfile()" ) ConditionProfile hasFrame ,
672
+ @ CachedLibrary (limit = "2" ) PythonObjectLibrary lib ,
673
+ @ Cached PermNode recursiveNode ) {
674
+ Object nValue = lib .asIndexWithFrame (n , hasFrame , frame );
675
+ Object kValue = lib .asIndexWithFrame (k , hasFrame , frame );
676
+ return recursiveNode .execute (frame , nValue , kValue );
677
+ }
678
+
679
+ public static PermNode create () {
680
+ return MathModuleBuiltinsFactory .PermNodeFactory .create ();
681
+ }
682
+ }
683
+
519
684
@ Builtin (name = "floor" , minNumOfPositionalArgs = 1 )
520
685
@ GenerateNodeFactory
521
686
@ ImportStatic (MathGuards .class )
0 commit comments