Skip to content

Commit 72632da

Browse files
committed
Implementation of math.dist()
1 parent 0137700 commit 72632da

File tree

4 files changed

+75
-1
lines changed

4 files changed

+75
-1
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/MathModuleBuiltins.java

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,17 @@
4343
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
4444
import com.oracle.graal.python.builtins.PythonBuiltins;
4545
import com.oracle.graal.python.builtins.objects.PNone;
46+
import com.oracle.graal.python.builtins.objects.common.SequenceNodes;
4647
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes;
4748
import com.oracle.graal.python.builtins.objects.floats.PFloat;
49+
import com.oracle.graal.python.builtins.objects.function.PArguments;
4850
import com.oracle.graal.python.builtins.objects.function.PKeyword;
4951
import com.oracle.graal.python.builtins.objects.ints.PInt;
5052
import com.oracle.graal.python.builtins.objects.object.PythonObjectLibrary;
5153
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
5254
import com.oracle.graal.python.nodes.ErrorMessages;
5355
import com.oracle.graal.python.nodes.PGuards;
56+
import com.oracle.graal.python.nodes.builtins.TupleNodes;
5457
import com.oracle.graal.python.nodes.call.special.LookupAndCallBinaryNode;
5558
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode;
5659
import com.oracle.graal.python.nodes.control.GetIteratorExpressionNode.GetIteratorNode;
@@ -2686,4 +2689,58 @@ public Object doGeneric(VirtualFrame frame, Object iterable, Object start) {
26862689
}
26872690
}
26882691
}
2692+
2693+
@Builtin(name = "dist", minNumOfPositionalArgs = 2, numOfPositionalOnlyArgs = 2, parameterNames = {"p", "q"})
2694+
@GenerateNodeFactory
2695+
public abstract static class DistNode extends PythonBuiltinNode {
2696+
2697+
@Child private TupleNodes.ConstructTupleNode tupleCtor = TupleNodes.ConstructTupleNode.create();
2698+
@Child private SequenceNodes.GetObjectArrayNode getObjectArray = SequenceNodes.GetObjectArrayNode.create();
2699+
2700+
@Specialization
2701+
public double doGeneric(VirtualFrame frame, Object p, Object q,
2702+
@CachedLibrary(limit = "4") PythonObjectLibrary lib) {
2703+
// adapted from CPython math_dist_impl and vector_norm
2704+
Object[] ps = getObjectArray.execute(tupleCtor.execute(frame, p));
2705+
Object[] qs = getObjectArray.execute(tupleCtor.execute(frame, q));
2706+
int len = ps.length;
2707+
if (len != qs.length) {
2708+
throw raise(ValueError, ErrorMessages.BOTH_POINTS_MUST_HAVE_THE_SAME_NUMBER_OF_DIMENSIONS);
2709+
}
2710+
double[] diffs = new double[len];
2711+
double max = 0.0;
2712+
boolean foundNan = false;
2713+
for (int i = 0; i < len; ++i) {
2714+
double a = lib.asJavaDoubleWithState(ps[i], PArguments.getThreadState(frame));
2715+
double b = lib.asJavaDoubleWithState(qs[i], PArguments.getThreadState(frame));
2716+
double x = Math.abs(a - b);
2717+
diffs[i] = x;
2718+
foundNan |= Double.isNaN(x);
2719+
if (x > max) {
2720+
max = x;
2721+
}
2722+
}
2723+
if (Double.isInfinite(max)) {
2724+
return max;
2725+
}
2726+
if (foundNan) {
2727+
return Double.NaN;
2728+
}
2729+
if (max == 0.0 || len <= 1) {
2730+
return max;
2731+
}
2732+
2733+
double csum = 1.0;
2734+
double frac = 0.0;
2735+
for (int i = 0; i < len; ++i) {
2736+
double x = diffs[i];
2737+
x /= max;
2738+
x = x * x;
2739+
double oldcsum = csum;
2740+
csum += x;
2741+
frac += (oldcsum - csum) + x;
2742+
}
2743+
return max * Math.sqrt(csum - 1.0 + frac);
2744+
}
2745+
}
26892746
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/SequenceNodes.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,5 +125,9 @@ static Object[] doGeneric(Object seq,
125125
@Cached SequenceStorageNodes.ToArrayNode toArrayNode) {
126126
return toArrayNode.execute(getSequenceStorageNode.execute(seq));
127127
}
128+
129+
public static GetObjectArrayNode create() {
130+
return SequenceNodesFactory.GetObjectArrayNodeGen.create();
131+
}
128132
}
129133
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/ints/PInt.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,14 +246,26 @@ public boolean canBeJavaDouble() {
246246
public double asJavaDouble(
247247
@CachedLibrary("this") PythonObjectLibrary lib,
248248
@Exclusive @Cached CastToJavaDoubleNode castToDouble,
249-
@Exclusive @Cached() ConditionProfile hasIndexFunc,
249+
@Exclusive @Cached ConditionProfile hasIndexFunc,
250250
@Exclusive @Cached PRaiseNode raise) {
251251
if (hasIndexFunc.profile(lib.canBeIndex(this))) {
252252
return castToDouble.execute(lib.asIndex(this));
253253
}
254254
throw raise.raise(TypeError, ErrorMessages.MUST_BE_REAL_NUMBER, this);
255255
}
256256

257+
@ExportMessage
258+
public double asJavaDoubleWithState(ThreadState threadState,
259+
@CachedLibrary("this") PythonObjectLibrary lib,
260+
@Exclusive @Cached CastToJavaDoubleNode castToDouble,
261+
@Exclusive @Cached ConditionProfile hasIndexFunc,
262+
@Exclusive @Cached PRaiseNode raise) {
263+
if (hasIndexFunc.profile(lib.canBeIndex(this))) {
264+
return castToDouble.execute(lib.asIndexWithState(this, threadState));
265+
}
266+
throw raise.raise(TypeError, ErrorMessages.MUST_BE_REAL_NUMBER, this);
267+
}
268+
257269
@SuppressWarnings("static-method")
258270
@ExportMessage
259271
public boolean canBeJavaLong() {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/ErrorMessages.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ public abstract class ErrorMessages {
8686
public static final String BASES_MUST_BE_TYPES = "bases must be types";
8787
public static final String BASES_ITEM_CAUSES_INHERITANCE_CYCLE = "a __bases__ item causes an inheritance cycle";
8888
public static final String BOOL_SHOULD_RETURN_BOOL = "__bool__ should return bool, returned %p";
89+
public static final String BOTH_POINTS_MUST_HAVE_THE_SAME_NUMBER_OF_DIMENSIONS = "both points must have the same number of dimensions";
8990
public static final String BUFFER_INDICES_MUST_BE_INTS = "buffer indices must be integers, not %p";
9091
public static final String BYTE_STR_IS_TOO_LARGE = "byte string is too large";
9192
public static final String BYTEARRAY_OUT_OF_BOUNDS = "bytearray index out of range";

0 commit comments

Comments
 (0)