Skip to content

Commit d364e0c

Browse files
committed
Refactoring of Math.LogNode
1 parent 4d08afe commit d364e0c

File tree

1 file changed

+80
-98
lines changed

1 file changed

+80
-98
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/MathModuleBuiltins.java

Lines changed: 80 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,14 @@
4242
import com.oracle.graal.python.builtins.objects.ints.PInt;
4343
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
4444
import com.oracle.graal.python.nodes.PNode;
45+
import com.oracle.graal.python.nodes.SpecialMethodNames;
4546
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode;
4647
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
4748
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
4849
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
4950
import com.oracle.graal.python.runtime.exception.PythonErrorType;
5051
import com.oracle.graal.python.runtime.object.PythonObjectFactory;
52+
import com.oracle.truffle.api.CompilerDirectives;
5153
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
5254
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
5355
import com.oracle.truffle.api.dsl.Cached;
@@ -1152,6 +1154,34 @@ protected static IsInfNode create() {
11521154
@GenerateNodeFactory
11531155
public abstract static class LogNode extends PythonUnaryBuiltinNode {
11541156

1157+
@Child private LookupAndCallUnaryNode valueDispatchNode;
1158+
@Child private LookupAndCallUnaryNode baseDispatchNode;
1159+
@Child private LogNode recLogNode;
1160+
1161+
private LookupAndCallUnaryNode getValueDispatchNode() {
1162+
if (valueDispatchNode == null) {
1163+
CompilerDirectives.transferToInterpreterAndInvalidate();
1164+
valueDispatchNode = insert(LookupAndCallUnaryNode.create(SpecialMethodNames.__FLOAT__));
1165+
}
1166+
return valueDispatchNode;
1167+
}
1168+
1169+
private LookupAndCallUnaryNode getBaseDispatchNode() {
1170+
if (baseDispatchNode == null) {
1171+
CompilerDirectives.transferToInterpreterAndInvalidate();
1172+
baseDispatchNode = insert(LookupAndCallUnaryNode.create(SpecialMethodNames.__FLOAT__));
1173+
}
1174+
return baseDispatchNode;
1175+
}
1176+
1177+
private double executeRecursiveLogNode(Object value, Object base) {
1178+
if (recLogNode == null) {
1179+
CompilerDirectives.transferToInterpreterAndInvalidate();
1180+
recLogNode = insert(LogNode.create());
1181+
}
1182+
return recLogNode.executeObject(value, base);
1183+
}
1184+
11551185
public abstract double executeObject(Object value, Object base);
11561186

11571187
private static final double LOG2 = Math.log(2.0);
@@ -1179,7 +1209,7 @@ private double countBase(BigInteger base, ConditionProfile divByZero) {
11791209
}
11801210
return logBase;
11811211
}
1182-
1212+
11831213
@Specialization
11841214
public double log(long value, @SuppressWarnings("unused") PNone novalue,
11851215
@Cached("createBinaryProfile()") ConditionProfile doNotFit) {
@@ -1189,39 +1219,35 @@ public double log(long value, @SuppressWarnings("unused") PNone novalue,
11891219
@Specialization
11901220
public double logDN(double value, @SuppressWarnings("unused") PNone novalue,
11911221
@Cached("createBinaryProfile()") ConditionProfile doNotFit) {
1192-
if (doNotFit.profile(value < 0)) {
1193-
throw raise(ValueError, "math domain error");
1194-
}
1222+
raiseMathError(doNotFit, value < 0);
11951223
return Math.log(value);
11961224
}
11971225

11981226
@Specialization
11991227
@TruffleBoundary
1200-
public double log(PInt value, @SuppressWarnings("unused") PNone novalue,
1228+
public double logPIN(PInt value, @SuppressWarnings("unused") PNone novalue,
12011229
@Cached("createBinaryProfile()") ConditionProfile doNotFit) {
12021230
BigInteger bValue = value.getValue();
1203-
if (doNotFit.profile(bValue.compareTo(BigInteger.ZERO) == -1)) {
1204-
throw raise(ValueError, "math domain error");
1205-
}
1231+
raiseMathError(doNotFit, bValue.compareTo(BigInteger.ZERO) == -1);
12061232
return logBigInteger(bValue);
12071233
}
12081234

12091235
@Specialization
1210-
public double log(long value, long base,
1236+
public double logLL(long value, long base,
12111237
@Cached("createBinaryProfile()") ConditionProfile doNotFit,
12121238
@Cached("createBinaryProfile()") ConditionProfile divByZero) {
12131239
return logDD(value, base, doNotFit, divByZero);
12141240
}
12151241

12161242
@Specialization
1217-
public double log(double value, long base,
1243+
public double logDL(double value, long base,
12181244
@Cached("createBinaryProfile()") ConditionProfile doNotFit,
12191245
@Cached("createBinaryProfile()") ConditionProfile divByZero) {
12201246
return logDD(value, base, doNotFit, divByZero);
12211247
}
12221248

12231249
@Specialization
1224-
public double log(long value, double base,
1250+
public double logLD(long value, double base,
12251251
@Cached("createBinaryProfile()") ConditionProfile doNotFit,
12261252
@Cached("createBinaryProfile()") ConditionProfile divByZero) {
12271253
return logDD(value, base, doNotFit, divByZero);
@@ -1231,28 +1257,24 @@ public double log(long value, double base,
12311257
public double logDD(double value, double base,
12321258
@Cached("createBinaryProfile()") ConditionProfile doNotFit,
12331259
@Cached("createBinaryProfile()") ConditionProfile divByZero) {
1234-
if (doNotFit.profile(value < 0 || base <= 0)) {
1235-
throw raise(ValueError, "math domain error");
1236-
}
1260+
raiseMathError(doNotFit, value < 0 || base <= 0);
12371261
double logBase = countBase(base, divByZero);
12381262
return Math.log(value) / logBase;
12391263
}
12401264

12411265
@Specialization
12421266
@TruffleBoundary
1243-
public double logDD(double value, PInt base,
1267+
public double logDPI(double value, PInt base,
12441268
@Cached("createBinaryProfile()") ConditionProfile doNotFit,
12451269
@Cached("createBinaryProfile()") ConditionProfile divByZero) {
12461270
BigInteger bBase = base.getValue();
1247-
if (doNotFit.profile(value < 0 || bBase.compareTo(BigInteger.ZERO) <= 0)) {
1248-
throw raise(ValueError, "math domain error");
1249-
}
1271+
raiseMathError(doNotFit, value < 0 || bBase.compareTo(BigInteger.ZERO) <= 0);
12501272
double logBase = countBase(bBase, divByZero);
12511273
return Math.log(value) / logBase;
12521274
}
12531275

12541276
@Specialization
1255-
public double log(PInt value, long base,
1277+
public double logPIL(PInt value, long base,
12561278
@Cached("createBinaryProfile()") ConditionProfile doNotFit,
12571279
@Cached("createBinaryProfile()") ConditionProfile divByZero) {
12581280
return logPID(value, base, doNotFit, divByZero);
@@ -1264,122 +1286,82 @@ public double logPID(PInt value, double base,
12641286
@Cached("createBinaryProfile()") ConditionProfile doNotFit,
12651287
@Cached("createBinaryProfile()") ConditionProfile divByZero) {
12661288
BigInteger bValue = value.getValue();
1267-
if (doNotFit.profile(bValue.compareTo(BigInteger.ZERO) == -1 || base <= 0)) {
1268-
throw raise(ValueError, "math domain error");
1269-
}
1289+
raiseMathError(doNotFit, bValue.compareTo(BigInteger.ZERO) == -1 || base <= 0);
12701290
double logBase = countBase(base, divByZero);
12711291
return logBigInteger(bValue) / logBase;
12721292
}
12731293

12741294
@Specialization
12751295
@TruffleBoundary
1276-
public double log(long value, PInt base,
1296+
public double logLPI(long value, PInt base,
12771297
@Cached("createBinaryProfile()") ConditionProfile doNotFit,
12781298
@Cached("createBinaryProfile()") ConditionProfile divByZero) {
12791299
BigInteger bBase = base.getValue();
1280-
if (doNotFit.profile(value < 0 || bBase.compareTo(BigInteger.ZERO) <= 0)) {
1281-
throw raise(ValueError, "math domain error");
1282-
}
1300+
raiseMathError(doNotFit, value < 0 || bBase.compareTo(BigInteger.ZERO) <= 0);
12831301
double logBase = countBase(bBase, divByZero);
12841302
return Math.log(value) / logBase;
12851303
}
12861304

12871305
@Specialization
12881306
@TruffleBoundary
1289-
public double log(PInt value, PInt base,
1307+
public double logPIPI(PInt value, PInt base,
12901308
@Cached("createBinaryProfile()") ConditionProfile doNotFit,
12911309
@Cached("createBinaryProfile()") ConditionProfile divByZero) {
12921310
BigInteger bValue = value.getValue();
12931311
BigInteger bBase = base.getValue();
1294-
if (doNotFit.profile(bValue.compareTo(BigInteger.ZERO) == -1 || bBase.compareTo(BigInteger.ZERO) <= 0)) {
1295-
throw raise(ValueError, "math domain error");
1296-
}
1312+
raiseMathError(doNotFit, bValue.compareTo(BigInteger.ZERO) == -1 || bBase.compareTo(BigInteger.ZERO) <= 0);
12971313
double logBase = countBase(bBase, divByZero);
12981314
return logBigInteger(bValue) / logBase;
12991315
}
13001316

13011317
@Specialization(guards = "!isNumber(value)")
1302-
public double log(Object value, @SuppressWarnings("unused") PNone novalue,
1303-
@Cached("create(__FLOAT__)") LookupAndCallUnaryNode dispatchFloat,
1304-
@Cached("create()") LogNode logNode) {
1305-
Object result = dispatchFloat.executeObject(value);
1306-
if (result == PNone.NO_VALUE) {
1307-
throw raise(TypeError, "must be real number, not %p", value);
1308-
}
1309-
return logNode.executeObject(result, novalue);
1318+
public double logO(Object value, @SuppressWarnings("unused") PNone novalue,
1319+
@Cached("createBinaryProfile()") ConditionProfile notNumber) {
1320+
Object result = getRealNumber(value, getValueDispatchNode(), notNumber);
1321+
return executeRecursiveLogNode(result, novalue);
13101322
}
13111323

1312-
@Specialization(guards = "!isNumber(value)")
1313-
public double log(Object value, long base,
1314-
@Cached("create(__FLOAT__)") LookupAndCallUnaryNode dispatchFloat,
1315-
@Cached("create()") LogNode logNode) {
1316-
return logOD(value, base, dispatchFloat, logNode);
1317-
}
1318-
1319-
@Specialization(guards = "!isNumber(value)")
1320-
public double logOD(Object value, double base,
1321-
@Cached("create(__FLOAT__)") LookupAndCallUnaryNode dispatchFloat,
1322-
@Cached("create()") LogNode logNode) {
1323-
Object result = dispatchFloat.executeObject(value);
1324-
if (result == PNone.NO_VALUE) {
1325-
throw raise(TypeError, "must be real number, not %p", value);
1326-
}
1327-
return logNode.executeObject(result, base);
1328-
}
1329-
1330-
@Specialization(guards = "!isNumber(value)")
1331-
public double log(Object value, PInt base,
1332-
@Cached("create(__FLOAT__)") LookupAndCallUnaryNode dispatchFloat,
1333-
@Cached("create()") LogNode logNode) {
1334-
Object result = dispatchFloat.executeObject(value);
1335-
if (result == PNone.NO_VALUE) {
1336-
throw raise(TypeError, "must be real number, not %p", value);
1337-
}
1338-
return logNode.executeObject(result, base);
1339-
}
1340-
1341-
@Specialization(guards = {"!isNumber(value)", "!isNumber(base)"})
1342-
public double log(Object value, Object base,
1343-
@Cached("create(__FLOAT__)") LookupAndCallUnaryNode dispatchFloat,
1344-
@Cached("create()") LogNode logNode) {
1345-
Object resultValue = dispatchFloat.executeObject(value);
1346-
if (resultValue == PNone.NO_VALUE) {
1347-
throw raise(TypeError, "must be real number, not %p", value);
1348-
}
1349-
Object resultBase = dispatchFloat.executeObject(base);
1350-
if (resultBase == PNone.NO_VALUE) {
1351-
throw raise(TypeError, "must be real number, not %p", base);
1352-
}
1353-
return logNode.executeObject(resultValue, resultBase);
1324+
@Specialization(guards = {"!isNumber(value)", "!isNoValue(base)"})
1325+
public double logOO(Object value, Object base,
1326+
@Cached("createBinaryProfile()") ConditionProfile notNumberValue,
1327+
@Cached("createBinaryProfile()") ConditionProfile notNumberBase) {
1328+
Object resultValue = getRealNumber(value, getValueDispatchNode(), notNumberValue);
1329+
Object resultBase = getRealNumber(base, getBaseDispatchNode(), notNumberBase);
1330+
return executeRecursiveLogNode(resultValue, resultBase);
13541331
}
13551332

13561333
@Specialization(guards = {"!isNumber(base)"})
1357-
public double log(long value, Object base,
1358-
@Cached("create(__FLOAT__)") LookupAndCallUnaryNode dispatchFloat,
1359-
@Cached("create()") LogNode logNode) {
1360-
return logDO(value, base, dispatchFloat, logNode);
1334+
public double logLO(long value, Object base,
1335+
@Cached("createBinaryProfile()") ConditionProfile notNumberBase) {
1336+
return logDO(value, base, notNumberBase);
13611337
}
13621338

13631339
@Specialization(guards = {"!isNumber(base)"})
13641340
public double logDO(double value, Object base,
1365-
@Cached("create(__FLOAT__)") LookupAndCallUnaryNode dispatchFloat,
1366-
@Cached("create()") LogNode logNode) {
1367-
Object resultBase = dispatchFloat.executeObject(base);
1368-
if (resultBase == PNone.NO_VALUE) {
1369-
throw raise(TypeError, "must be real number, not %p", base);
1370-
}
1371-
return logNode.executeObject(value, resultBase);
1341+
@Cached("createBinaryProfile()") ConditionProfile notNumberBase) {
1342+
Object resultBase = getRealNumber(base, getBaseDispatchNode(), notNumberBase);
1343+
return executeRecursiveLogNode(value, resultBase);
13721344
}
13731345

13741346
@Specialization(guards = {"!isNumber(base)"})
1375-
public double log(PInt value, Object base,
1376-
@Cached("create(__FLOAT__)") LookupAndCallUnaryNode dispatchFloat,
1377-
@Cached("create()") LogNode logNode) {
1378-
Object resultBase = dispatchFloat.executeObject(base);
1379-
if (resultBase == PNone.NO_VALUE) {
1380-
throw raise(TypeError, "must be real number, not %p", base);
1347+
public double logPIO(PInt value, Object base,
1348+
@Cached("createBinaryProfile()") ConditionProfile notNumberBase) {
1349+
Object resultBase = getRealNumber(base, getBaseDispatchNode(), notNumberBase);
1350+
return executeRecursiveLogNode(value, resultBase);
1351+
}
1352+
1353+
private void raiseMathError(ConditionProfile doNotFit, boolean con) {
1354+
if (doNotFit.profile(con)) {
1355+
throw raise(ValueError, "math domain error");
13811356
}
1382-
return logNode.executeObject(value, resultBase);
1357+
}
1358+
1359+
private Object getRealNumber(Object object, LookupAndCallUnaryNode dispatchNode, ConditionProfile isNotRealNumber) {
1360+
Object result = dispatchNode.executeObject(object);
1361+
if (result == PNone.NO_VALUE) {
1362+
throw raise(TypeError, "must be real number, not %p", object);
1363+
}
1364+
return result;
13831365
}
13841366

13851367
public static LogNode create() {

0 commit comments

Comments
 (0)