Skip to content

Commit 99af87b

Browse files
committed
[GR-23312] Implementation of math functions comb, perm and remainder, fixed math.factorial()
PullRequest: graalpython/1005
2 parents 8d5bb33 + 2114d7a commit 99af87b

File tree

2 files changed

+199
-7
lines changed

2 files changed

+199
-7
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/MathModuleBuiltins.java

Lines changed: 197 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ public int factorialBoolean(@SuppressWarnings("unused") boolean value) {
367367

368368
@Specialization(guards = {"value < 0"})
369369
public long factorialNegativeInt(@SuppressWarnings("unused") int value) {
370-
throw raise(PythonErrorType.ValueError, ErrorMessages.FACTORIAL_NOT_DEFNED_FOR_NEGATIVE);
370+
throw raise(PythonErrorType.ValueError, ErrorMessages.FACTORIAL_NOT_DEFINED_FOR_NEGATIVE);
371371
}
372372

373373
@Specialization(guards = {"0 <= value", "value < SMALL_FACTORIALS.length"})
@@ -382,7 +382,7 @@ public PInt factorialInt(int value) {
382382

383383
@Specialization(guards = {"value < 0"})
384384
public long factorialNegativeLong(@SuppressWarnings("unused") long value) {
385-
throw raise(PythonErrorType.ValueError, ErrorMessages.FACTORIAL_NOT_DEFNED_FOR_NEGATIVE);
385+
throw raise(PythonErrorType.ValueError, ErrorMessages.FACTORIAL_NOT_DEFINED_FOR_NEGATIVE);
386386
}
387387

388388
@Specialization(guards = {"0 <= value", "value < SMALL_FACTORIALS.length"})
@@ -397,7 +397,7 @@ public PInt factorialLong(long value) {
397397

398398
@Specialization(guards = "isNegative(value)")
399399
public Object factorialPINegative(@SuppressWarnings("unused") PInt value) {
400-
throw raise(PythonErrorType.ValueError, ErrorMessages.FACTORIAL_NOT_DEFNED_FOR_NEGATIVE);
400+
throw raise(PythonErrorType.ValueError, ErrorMessages.FACTORIAL_NOT_DEFINED_FOR_NEGATIVE);
401401
}
402402

403403
@Specialization(guards = "isOvf(value)")
@@ -427,7 +427,7 @@ public long factorialDoubleInfinite(@SuppressWarnings("unused") double value) {
427427

428428
@Specialization(guards = "isNegative(value)")
429429
public PInt factorialDoubleNegative(@SuppressWarnings("unused") double value) {
430-
throw raise(PythonErrorType.ValueError, ErrorMessages.FACTORIAL_NOT_DEFNED_FOR_NEGATIVE);
430+
throw raise(PythonErrorType.ValueError, ErrorMessages.FACTORIAL_NOT_DEFINED_FOR_NEGATIVE);
431431
}
432432

433433
@Specialization(guards = "!isInteger(value)")
@@ -461,7 +461,7 @@ public long factorialPFLInfinite(@SuppressWarnings("unused") PFloat value) {
461461

462462
@Specialization(guards = "isNegative(value.getValue())")
463463
public PInt factorialPFLNegative(@SuppressWarnings("unused") PFloat value) {
464-
throw raise(PythonErrorType.ValueError, ErrorMessages.FACTORIAL_NOT_DEFNED_FOR_NEGATIVE);
464+
throw raise(PythonErrorType.ValueError, ErrorMessages.FACTORIAL_NOT_DEFINED_FOR_NEGATIVE);
465465
}
466466

467467
@Specialization(guards = "!isInteger(value.getValue())")
@@ -488,7 +488,7 @@ public Object factorialPFL(PFloat value) {
488488
public Object factorialObject(VirtualFrame frame, Object value,
489489
@CachedLibrary("value") PythonObjectLibrary lib,
490490
@Cached("create()") FactorialNode recursiveNode) {
491-
return recursiveNode.execute(frame, lib.asPInt(value));
491+
return recursiveNode.execute(frame, lib.asIndex(value));
492492
}
493493

494494
protected boolean isInteger(double value) {
@@ -516,6 +516,151 @@ protected static FactorialNode create() {
516516
}
517517
}
518518

519+
@Builtin(name = "comb", minNumOfPositionalArgs = 2)
520+
@TypeSystemReference(PythonArithmeticTypes.class)
521+
@GenerateNodeFactory
522+
@ImportStatic(MathGuards.class)
523+
public abstract static class CombNode extends PythonBinaryBuiltinNode {
524+
525+
@TruffleBoundary
526+
private BigInteger calculateComb(BigInteger n, BigInteger k) {
527+
if (n.signum() < 0) {
528+
throw raise(ValueError, ErrorMessages.MUST_BE_NON_NEGATIVE_INTEGER, "n");
529+
}
530+
if (k.signum() < 0) {
531+
throw raise(ValueError, ErrorMessages.MUST_BE_NON_NEGATIVE_INTEGER, "k");
532+
}
533+
534+
BigInteger factors = k.min(n.subtract(k));
535+
if (factors.signum() < 0) {
536+
return BigInteger.ZERO;
537+
}
538+
if (factors.signum() == 0) {
539+
return BigInteger.ONE;
540+
}
541+
BigInteger result = n;
542+
BigInteger factor = n;
543+
BigInteger i = BigInteger.ONE;
544+
while (i.compareTo(factors) < 0) {
545+
factor = factor.subtract(BigInteger.ONE);
546+
result = result.multiply(factor);
547+
i = i.add(BigInteger.ONE);
548+
result = result.divide(i);
549+
}
550+
return result;
551+
}
552+
553+
@Specialization
554+
PInt comb(long n, long k) {
555+
return factory().createInt(calculateComb(PInt.longToBigInteger(n), PInt.longToBigInteger(k)));
556+
}
557+
558+
@Specialization
559+
PInt comb(long n, PInt k) {
560+
return factory().createInt(calculateComb(PInt.longToBigInteger(n), k.getValue()));
561+
}
562+
563+
@Specialization
564+
PInt comb(PInt n, long k) {
565+
return factory().createInt(calculateComb(n.getValue(), PInt.longToBigInteger(k)));
566+
}
567+
568+
@Specialization
569+
PInt comb(PInt n, PInt k) {
570+
return factory().createInt(calculateComb(n.getValue(), k.getValue()));
571+
}
572+
573+
@Specialization
574+
Object comb(VirtualFrame frame, Object n, Object k,
575+
@Cached("createBinaryProfile()") ConditionProfile hasFrame,
576+
@CachedLibrary(limit = "2") PythonObjectLibrary lib,
577+
@Cached CombNode recursiveNode) {
578+
Object nValue = lib.asIndexWithFrame(n, hasFrame, frame);
579+
Object kValue = lib.asIndexWithFrame(k, hasFrame, frame);
580+
return recursiveNode.execute(frame, nValue, kValue);
581+
}
582+
583+
public static CombNode create() {
584+
return MathModuleBuiltinsFactory.CombNodeFactory.create();
585+
}
586+
}
587+
588+
@Builtin(name = "perm", minNumOfPositionalArgs = 1, parameterNames = {"n", "k"})
589+
@TypeSystemReference(PythonArithmeticTypes.class)
590+
@GenerateNodeFactory
591+
@ImportStatic(MathGuards.class)
592+
public abstract static class PermNode extends PythonBinaryBuiltinNode {
593+
594+
@TruffleBoundary
595+
private BigInteger calculatePerm(BigInteger n, BigInteger k) {
596+
if (n.signum() < 0) {
597+
throw raise(ValueError, ErrorMessages.MUST_BE_NON_NEGATIVE_INTEGER, "n");
598+
}
599+
if (k.signum() < 0) {
600+
throw raise(ValueError, ErrorMessages.MUST_BE_NON_NEGATIVE_INTEGER, "k");
601+
}
602+
if (n.compareTo(k) < 0) {
603+
return BigInteger.ZERO;
604+
}
605+
if (k.equals(BigInteger.ZERO)) {
606+
return BigInteger.ONE;
607+
}
608+
if (k.equals(BigInteger.ONE)) {
609+
return n;
610+
}
611+
612+
BigInteger result = n;
613+
BigInteger factor = n;
614+
BigInteger i = BigInteger.ONE;
615+
while (i.compareTo(k) < 0) {
616+
factor = factor.subtract(BigInteger.ONE);
617+
result = result.multiply(factor);
618+
i = i.add(BigInteger.ONE);
619+
}
620+
return result;
621+
}
622+
623+
@Specialization
624+
PInt perm(long n, long k) {
625+
return factory().createInt(calculatePerm(PInt.longToBigInteger(n), PInt.longToBigInteger(k)));
626+
}
627+
628+
@Specialization
629+
PInt perm(long n, PInt k) {
630+
return factory().createInt(calculatePerm(PInt.longToBigInteger(n), k.getValue()));
631+
}
632+
633+
@Specialization
634+
PInt perm(PInt n, long k) {
635+
return factory().createInt(calculatePerm(n.getValue(), PInt.longToBigInteger(k)));
636+
}
637+
638+
@Specialization
639+
PInt perm(PInt n, PInt k) {
640+
return factory().createInt(calculatePerm(n.getValue(), k.getValue()));
641+
}
642+
643+
@Specialization
644+
Object perm(VirtualFrame frame, Object n, @SuppressWarnings("unused") PNone k,
645+
@Cached FactorialNode factorialNode) {
646+
return factorialNode.execute(frame, n);
647+
}
648+
649+
@Specialization(guards = "!isPNone(k)")
650+
Object perm(VirtualFrame frame, Object n, Object k,
651+
@Cached("createBinaryProfile()") ConditionProfile hasFrame,
652+
@CachedLibrary(limit = "2") PythonObjectLibrary lib,
653+
@Cached PermNode recursiveNode) {
654+
Object nValue = lib.asIndexWithFrame(n, hasFrame, frame);
655+
Object kValue = lib.asIndexWithFrame(k, hasFrame, frame);
656+
return recursiveNode.execute(frame, nValue, kValue);
657+
}
658+
659+
public static PermNode create() {
660+
return MathModuleBuiltinsFactory.PermNodeFactory.create();
661+
}
662+
}
663+
519664
@Builtin(name = "floor", minNumOfPositionalArgs = 1)
520665
@GenerateNodeFactory
521666
@ImportStatic(MathGuards.class)
@@ -689,6 +834,52 @@ protected void raiseMathDomainError(boolean con) {
689834

690835
}
691836

837+
@Builtin(name = "remainder", minNumOfPositionalArgs = 2)
838+
@TypeSystemReference(PythonArithmeticTypes.class)
839+
@ImportStatic(MathGuards.class)
840+
@GenerateNodeFactory
841+
public abstract static class RemainderNode extends PythonBinaryBuiltinNode {
842+
843+
@Specialization
844+
double remainderDD(double x, double y) {
845+
if (Double.isFinite(x) && Double.isFinite(y)) {
846+
if (y == 0.0) {
847+
throw raise(ValueError, ErrorMessages.MATH_DOMAIN_ERROR);
848+
}
849+
double absx = Math.abs(x);
850+
double absy = Math.abs(y);
851+
double m = absx % absy;
852+
double c = absy - m;
853+
double r;
854+
if (m < c) {
855+
r = m;
856+
} else if (m > c) {
857+
r = -c;
858+
} else {
859+
r = m - 2.0 * ((0.5 * (absx - m)) % absy);
860+
}
861+
return Math.copySign(1.0, x) * r;
862+
}
863+
if (Double.isNaN(x)) {
864+
return x;
865+
}
866+
if (Double.isNaN(y)) {
867+
return y;
868+
}
869+
if (Double.isInfinite(x)) {
870+
throw raise(ValueError, ErrorMessages.MATH_DOMAIN_ERROR);
871+
}
872+
return x;
873+
}
874+
875+
@Specialization(limit = "1")
876+
double remainderOO(Object x, Object y,
877+
@CachedLibrary("x") PythonObjectLibrary xLib,
878+
@CachedLibrary("y") PythonObjectLibrary yLib) {
879+
return remainderDD(xLib.asJavaDouble(x), yLib.asJavaDouble(y));
880+
}
881+
}
882+
692883
@Builtin(name = "frexp", minNumOfPositionalArgs = 1)
693884
@TypeSystemReference(PythonArithmeticTypes.class)
694885
@ImportStatic(MathGuards.class)

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/ErrorMessages.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ public abstract class ErrorMessages {
191191
public static final String EXPECTED_STR_BYTE_OSPATHLIKE_OBJ = "expected str, bytes or os.PathLike object, not %p";
192192
public static final String EXPECTED_UNICODE_CHAR_NOT_P = "expected a unicode character, not %p";
193193
public static final String EXPONENT_TOO_LARGE = "exponent too large";
194-
public static final String FACTORIAL_NOT_DEFNED_FOR_NEGATIVE = "factorial() not defined for negative values";
194+
public static final String FACTORIAL_NOT_DEFINED_FOR_NEGATIVE = "factorial() not defined for negative values";
195195
public static final String FILE_NOT_OPENED_FOR_READING = "file not opened for reading";
196196
public static final String FILL_CHAR_MUST_BE_LENGTH_1 = "The fill character must be exactly one character long";
197197
public static final String FILTER_SPEC_MUST_BE_DICT = "Filter specifier must be a dict or dict-like object";
@@ -309,6 +309,7 @@ public abstract class ErrorMessages {
309309
public static final String MUST_BE_BYTE_STRING_LEGTH1_NOT_P = "must be a byte string of length 1, not %p";
310310
public static final String MUST_BE_EITHER_OR = "%s: '%s' must be either %s or %s";
311311
public static final String MUST_BE_NON_NEGATIVE = "%s must be non-negative";
312+
public static final String MUST_BE_NON_NEGATIVE_INTEGER = "%s must be non-negative integer";
312313
public static final String MUST_BE_NUMERIC = "must be numeric, not %p";
313314
public static final String MUST_BE_REAL_NUMBER = "must be real number, not %p";
314315
public static final String MUST_BE_S_NOT_P = "%s must be a %s, not %p";

0 commit comments

Comments
 (0)