Skip to content

Commit 76b8001

Browse files
committed
Implementation of math functions comb and perm
1 parent 87a2109 commit 76b8001

File tree

2 files changed

+166
-0
lines changed

2 files changed

+166
-0
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/MathModuleBuiltins.java

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,171 @@ protected static FactorialNode create() {
516516
}
517517
}
518518

519+
@Builtin(name = "comb", minNumOfPositionalArgs = 2)
520+
@TypeSystemReference(PythonArithmeticTypes.class)
521+
@GenerateNodeFactory
522+
@ImportStatic(MathGuards.class)
523+
public abstract static class CombNode extends PythonBinaryBuiltinNode {
524+
525+
@TruffleBoundary
526+
private BigInteger calculateComb(BigInteger n, BigInteger k) {
527+
if (n.signum() < 0) {
528+
throw raise(ValueError, ErrorMessages.MUST_BE_NON_NEGATIVE_INTEGER, "n");
529+
}
530+
if (k.signum() < 0) {
531+
throw raise(ValueError, ErrorMessages.MUST_BE_NON_NEGATIVE_INTEGER, "k");
532+
}
533+
534+
BigInteger factors = k.min(n.subtract(k));
535+
if (factors.signum() < 0) {
536+
return BigInteger.ZERO;
537+
}
538+
if (factors.signum() == 0) {
539+
return BigInteger.ONE;
540+
}
541+
BigInteger result = n;
542+
BigInteger factor = n;
543+
BigInteger i = BigInteger.ONE;
544+
while (i.compareTo(factors) < 0) {
545+
factor = factor.subtract(BigInteger.ONE);
546+
result = result.multiply(factor);
547+
i = i.add(BigInteger.ONE);
548+
result = result.divide(i);
549+
}
550+
return result;
551+
}
552+
553+
@Specialization
554+
PInt comb(long n, long k) {
555+
return factory().createInt(calculateComb(PInt.longToBigInteger(n), PInt.longToBigInteger(k)));
556+
}
557+
558+
@Specialization
559+
PInt comb(long n, PInt k) {
560+
return factory().createInt(calculateComb(PInt.longToBigInteger(n), k.getValue()));
561+
}
562+
563+
@Specialization
564+
PInt comb(PInt n, long k) {
565+
return factory().createInt(calculateComb(n.getValue(), PInt.longToBigInteger(k)));
566+
}
567+
568+
@Specialization
569+
PInt comb(PInt n, PInt k) {
570+
return factory().createInt(calculateComb(n.getValue(), k.getValue()));
571+
}
572+
573+
@Specialization
574+
int comb(@SuppressWarnings("unused") double n, @SuppressWarnings("unused") Object k) {
575+
throw raise(TypeError, ErrorMessages.OBJ_CANNOT_BE_INTERPRETED_AS_INTEGER, "float");
576+
}
577+
578+
@Specialization
579+
int comb(@SuppressWarnings("unused") Object n, @SuppressWarnings("unused") double k) {
580+
throw raise(TypeError, ErrorMessages.OBJ_CANNOT_BE_INTERPRETED_AS_INTEGER, "float");
581+
}
582+
583+
@Specialization(guards = "!isNumber(n) || !isNumber(k)")
584+
Object comb(VirtualFrame frame, Object n, Object k,
585+
@Cached("createBinaryProfile()") ConditionProfile hasFrame,
586+
@CachedLibrary(limit = "2") PythonObjectLibrary lib,
587+
@Cached CombNode recursiveNode) {
588+
Object nValue = lib.asIndexWithFrame(n, hasFrame, frame);
589+
Object kValue = lib.asIndexWithFrame(k, hasFrame, frame);
590+
return recursiveNode.execute(frame, nValue, kValue);
591+
}
592+
593+
public static CombNode create() {
594+
return MathModuleBuiltinsFactory.CombNodeFactory.create();
595+
}
596+
}
597+
598+
@Builtin(name = "perm", minNumOfPositionalArgs = 1, parameterNames = {"n", "k"})
599+
@TypeSystemReference(PythonArithmeticTypes.class)
600+
@GenerateNodeFactory
601+
@ImportStatic(MathGuards.class)
602+
public abstract static class PermNode extends PythonBinaryBuiltinNode {
603+
604+
@TruffleBoundary
605+
private BigInteger calculatePerm(BigInteger n, BigInteger k) {
606+
if (n.signum() < 0) {
607+
throw raise(ValueError, ErrorMessages.MUST_BE_NON_NEGATIVE_INTEGER, "n");
608+
}
609+
if (k.signum() < 0) {
610+
throw raise(ValueError, ErrorMessages.MUST_BE_NON_NEGATIVE_INTEGER, "k");
611+
}
612+
if (n.compareTo(k) < 0) {
613+
return BigInteger.ZERO;
614+
}
615+
if (k.equals(BigInteger.ZERO)) {
616+
return BigInteger.ONE;
617+
}
618+
if (k.equals(BigInteger.ONE)) {
619+
return n;
620+
}
621+
622+
BigInteger result = n;
623+
BigInteger factor = n;
624+
BigInteger i = BigInteger.ONE;
625+
while (i.compareTo(k) < 0) {
626+
factor = factor.subtract(BigInteger.ONE);
627+
result = result.multiply(factor);
628+
i = i.add(BigInteger.ONE);
629+
}
630+
return result;
631+
}
632+
633+
@Specialization
634+
PInt perm(long n, long k) {
635+
return factory().createInt(calculatePerm(PInt.longToBigInteger(n), PInt.longToBigInteger(k)));
636+
}
637+
638+
@Specialization
639+
PInt perm(long n, PInt k) {
640+
return factory().createInt(calculatePerm(PInt.longToBigInteger(n), k.getValue()));
641+
}
642+
643+
@Specialization
644+
PInt perm(PInt n, long k) {
645+
return factory().createInt(calculatePerm(n.getValue(), PInt.longToBigInteger(k)));
646+
}
647+
648+
@Specialization
649+
PInt perm(PInt n, PInt k) {
650+
return factory().createInt(calculatePerm(n.getValue(), k.getValue()));
651+
}
652+
653+
@Specialization
654+
int perm(@SuppressWarnings("unused") double n, @SuppressWarnings("unused") Object k) {
655+
throw raise(TypeError, ErrorMessages.OBJ_CANNOT_BE_INTERPRETED_AS_INTEGER, "float");
656+
}
657+
658+
@Specialization
659+
int perm(@SuppressWarnings("unused") Object n, @SuppressWarnings("unused") double k) {
660+
throw raise(TypeError, ErrorMessages.OBJ_CANNOT_BE_INTERPRETED_AS_INTEGER, "float");
661+
}
662+
663+
@Specialization
664+
Object perm(VirtualFrame frame, Object n, @SuppressWarnings("unused") PNone k,
665+
@Cached FactorialNode factorialNode) {
666+
return factorialNode.execute(frame, n);
667+
}
668+
669+
@Specialization(guards = "!isNumber(n) || !isNumber(k)")
670+
Object perm(VirtualFrame frame, Object n, Object k,
671+
@Cached("createBinaryProfile()") ConditionProfile hasFrame,
672+
@CachedLibrary(limit = "2") PythonObjectLibrary lib,
673+
@Cached PermNode recursiveNode) {
674+
Object nValue = lib.asIndexWithFrame(n, hasFrame, frame);
675+
Object kValue = lib.asIndexWithFrame(k, hasFrame, frame);
676+
return recursiveNode.execute(frame, nValue, kValue);
677+
}
678+
679+
public static PermNode create() {
680+
return MathModuleBuiltinsFactory.PermNodeFactory.create();
681+
}
682+
}
683+
519684
@Builtin(name = "floor", minNumOfPositionalArgs = 1)
520685
@GenerateNodeFactory
521686
@ImportStatic(MathGuards.class)

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/ErrorMessages.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ public abstract class ErrorMessages {
308308
public static final String MUST_BE_BYTE_STRING_LEGTH1_NOT_P = "must be a byte string of length 1, not %p";
309309
public static final String MUST_BE_EITHER_OR = "%s: '%s' must be either %s or %s";
310310
public static final String MUST_BE_NON_NEGATIVE = "%s must be non-negative";
311+
public static final String MUST_BE_NON_NEGATIVE_INTEGER = "%s must be non-negative integer";
311312
public static final String MUST_BE_NUMERIC = "must be numeric, not %p";
312313
public static final String MUST_BE_REAL_NUMBER = "must be real number, not %p";
313314
public static final String MUST_BE_S_NOT_P = "%s must be a %s, not %p";

0 commit comments

Comments
 (0)