|
43 | 43 | import com.oracle.graal.python.builtins.PythonBuiltinClassType;
|
44 | 44 | import com.oracle.graal.python.builtins.PythonBuiltins;
|
45 | 45 | import com.oracle.graal.python.builtins.objects.PNone;
|
| 46 | +import com.oracle.graal.python.builtins.objects.common.SequenceNodes; |
46 | 47 | import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes;
|
47 | 48 | import com.oracle.graal.python.builtins.objects.floats.PFloat;
|
| 49 | +import com.oracle.graal.python.builtins.objects.function.PArguments; |
48 | 50 | import com.oracle.graal.python.builtins.objects.function.PKeyword;
|
49 | 51 | import com.oracle.graal.python.builtins.objects.ints.PInt;
|
50 | 52 | import com.oracle.graal.python.builtins.objects.object.PythonObjectLibrary;
|
51 | 53 | import com.oracle.graal.python.builtins.objects.tuple.PTuple;
|
52 | 54 | import com.oracle.graal.python.nodes.ErrorMessages;
|
53 | 55 | import com.oracle.graal.python.nodes.PGuards;
|
| 56 | +import com.oracle.graal.python.nodes.builtins.TupleNodes; |
54 | 57 | import com.oracle.graal.python.nodes.call.special.LookupAndCallBinaryNode;
|
55 | 58 | import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode;
|
56 | 59 | import com.oracle.graal.python.nodes.control.GetIteratorExpressionNode.GetIteratorNode;
|
@@ -2686,4 +2689,58 @@ public Object doGeneric(VirtualFrame frame, Object iterable, Object start) {
|
2686 | 2689 | }
|
2687 | 2690 | }
|
2688 | 2691 | }
|
| 2692 | + |
| 2693 | + @Builtin(name = "dist", minNumOfPositionalArgs = 2, numOfPositionalOnlyArgs = 2, parameterNames = {"p", "q"}) |
| 2694 | + @GenerateNodeFactory |
| 2695 | + public abstract static class DistNode extends PythonBuiltinNode { |
| 2696 | + |
| 2697 | + @Child private TupleNodes.ConstructTupleNode tupleCtor = TupleNodes.ConstructTupleNode.create(); |
| 2698 | + @Child private SequenceNodes.GetObjectArrayNode getObjectArray = SequenceNodes.GetObjectArrayNode.create(); |
| 2699 | + |
| 2700 | + @Specialization |
| 2701 | + public double doGeneric(VirtualFrame frame, Object p, Object q, |
| 2702 | + @CachedLibrary(limit = "4") PythonObjectLibrary lib) { |
| 2703 | + // adapted from CPython math_dist_impl and vector_norm |
| 2704 | + Object[] ps = getObjectArray.execute(tupleCtor.execute(frame, p)); |
| 2705 | + Object[] qs = getObjectArray.execute(tupleCtor.execute(frame, q)); |
| 2706 | + int len = ps.length; |
| 2707 | + if (len != qs.length) { |
| 2708 | + throw raise(ValueError, ErrorMessages.BOTH_POINTS_MUST_HAVE_THE_SAME_NUMBER_OF_DIMENSIONS); |
| 2709 | + } |
| 2710 | + double[] diffs = new double[len]; |
| 2711 | + double max = 0.0; |
| 2712 | + boolean foundNan = false; |
| 2713 | + for (int i = 0; i < len; ++i) { |
| 2714 | + double a = lib.asJavaDoubleWithState(ps[i], PArguments.getThreadState(frame)); |
| 2715 | + double b = lib.asJavaDoubleWithState(qs[i], PArguments.getThreadState(frame)); |
| 2716 | + double x = Math.abs(a - b); |
| 2717 | + diffs[i] = x; |
| 2718 | + foundNan |= Double.isNaN(x); |
| 2719 | + if (x > max) { |
| 2720 | + max = x; |
| 2721 | + } |
| 2722 | + } |
| 2723 | + if (Double.isInfinite(max)) { |
| 2724 | + return max; |
| 2725 | + } |
| 2726 | + if (foundNan) { |
| 2727 | + return Double.NaN; |
| 2728 | + } |
| 2729 | + if (max == 0.0 || len <= 1) { |
| 2730 | + return max; |
| 2731 | + } |
| 2732 | + |
| 2733 | + double csum = 1.0; |
| 2734 | + double frac = 0.0; |
| 2735 | + for (int i = 0; i < len; ++i) { |
| 2736 | + double x = diffs[i]; |
| 2737 | + x /= max; |
| 2738 | + x = x * x; |
| 2739 | + double oldcsum = csum; |
| 2740 | + csum += x; |
| 2741 | + frac += (oldcsum - csum) + x; |
| 2742 | + } |
| 2743 | + return max * Math.sqrt(csum - 1.0 + frac); |
| 2744 | + } |
| 2745 | + } |
2689 | 2746 | }
|
0 commit comments