|
70 | 70 | import com.oracle.graal.python.compiler.Unparser;
|
71 | 71 | import com.oracle.graal.python.compiler.bytecode_dsl.BytecodeDSLCompiler.BytecodeDSLCompilerContext;
|
72 | 72 | import com.oracle.graal.python.compiler.bytecode_dsl.BytecodeDSLCompiler.BytecodeDSLCompilerResult;
|
| 73 | +import com.oracle.graal.python.lib.PyObjectRichCompareBool; |
73 | 74 | import com.oracle.graal.python.nodes.StringLiterals;
|
74 | 75 | import com.oracle.graal.python.nodes.bytecode_dsl.BytecodeDSLCodeUnit;
|
75 | 76 | import com.oracle.graal.python.nodes.bytecode_dsl.PBytecodeDSLRootNode;
|
@@ -3641,7 +3642,7 @@ private void visitPattern(PatternTy pattern, PatternContext pc) {
|
3641 | 3642 | } else if (pattern instanceof PatternTy.MatchClass matchClass) {
|
3642 | 3643 | doVisitPattern(matchClass);
|
3643 | 3644 | } else if (pattern instanceof PatternTy.MatchMapping matchMapping) {
|
3644 |
| - doVisitPattern(matchMapping); |
| 3645 | + doVisitPattern(matchMapping, pc); |
3645 | 3646 | } else if (pattern instanceof PatternTy.MatchOr matchOr) {
|
3646 | 3647 | doVisitPattern(matchOr, pc);
|
3647 | 3648 | } else if (pattern instanceof PatternTy.MatchSequence matchSequence) {
|
@@ -3701,8 +3702,155 @@ private void doVisitPattern(PatternTy.MatchClass node) {
|
3701 | 3702 | emitPatternNotImplemented("class");
|
3702 | 3703 | }
|
3703 | 3704 |
|
3704 |
| - private void doVisitPattern(PatternTy.MatchMapping node) { |
3705 |
| - emitPatternNotImplemented("mapping"); |
| 3705 | + private static int lengthOrZero(Object[] p) { |
| 3706 | + return p == null ? 0 : p.length; |
| 3707 | + } |
| 3708 | + |
| 3709 | + private void doVisitPattern(PatternTy.MatchMapping node, PatternContext pc) { |
| 3710 | + ExprTy[] keys = node.keys; |
| 3711 | + PatternTy[] patterns = node.patterns; |
| 3712 | + |
| 3713 | + int key_len = lengthOrZero(keys); |
| 3714 | + int pat_len = lengthOrZero(patterns); |
| 3715 | + |
| 3716 | + if (key_len != pat_len) { |
| 3717 | + ctx.errorCallback.onError(ErrorType.Syntax, node.getSourceRange(), "keys (%d) / patterns (%d) length mismatch in mapping pattern", key_len, pat_len); |
| 3718 | + } |
| 3719 | + |
| 3720 | + String starTarget = node.rest; |
| 3721 | + if (key_len == 0 && starTarget == null) { |
| 3722 | + b.emitLoadConstant(false); |
| 3723 | + return; |
| 3724 | + } |
| 3725 | + if (Integer.MAX_VALUE < key_len - 1) { |
| 3726 | + ctx.errorCallback.onError(ErrorType.Syntax, node.getSourceRange(), "too many sub-patterns in mapping pattern"); |
| 3727 | + } |
| 3728 | + |
| 3729 | + b.beginPrimitiveBoolAnd(); // AND for sanity checks (length matching, etc.) |
| 3730 | + |
| 3731 | + if (key_len > 0) { |
| 3732 | + // If the pattern has any keys in it, perform a length check: |
| 3733 | + b.beginGe(); |
| 3734 | + b.beginGetLen(); |
| 3735 | + b.emitLoadLocal(pc.subject); |
| 3736 | + b.endGetLen(); |
| 3737 | + b.emitLoadConstant(key_len); |
| 3738 | + b.endGe(); |
| 3739 | + } |
| 3740 | + |
| 3741 | + b.beginBlock(); |
| 3742 | + |
| 3743 | + // save pc.subject |
| 3744 | + BytecodeLocal pc_save = b.createLocal(); |
| 3745 | + b.beginStoreLocal(pc_save); |
| 3746 | + b.emitLoadLocal(pc.subject); |
| 3747 | + b.endStoreLocal(); |
| 3748 | + |
| 3749 | + // check that type matches |
| 3750 | + b.beginCheckTypeFlags(TypeFlags.MAPPING); |
| 3751 | + b.emitLoadLocal(pc.subject); |
| 3752 | + b.endCheckTypeFlags(); |
| 3753 | + |
| 3754 | + // match keys and get array of values |
| 3755 | + BytecodeLocal keys_local = b.createLocal(); |
| 3756 | + b.beginStoreLocal(keys_local); |
| 3757 | + b.beginCollectToObjectArray(); |
| 3758 | + List<Object> seen = new ArrayList<>(); |
| 3759 | + for (int i = 0; i < key_len; i++) { |
| 3760 | + ExprTy key = keys[i]; |
| 3761 | + if (key instanceof ExprTy.Attribute) { |
| 3762 | + key.accept(this); |
| 3763 | + } else { |
| 3764 | + ConstantValue constantValue = null; |
| 3765 | + if (key instanceof ExprTy.UnaryOp || key instanceof ExprTy.BinOp) { |
| 3766 | + constantValue = foldConstantOp(key); |
| 3767 | + } else if (key instanceof ExprTy.Constant) { |
| 3768 | + constantValue = ((ExprTy.Constant) key).value; |
| 3769 | + } else { |
| 3770 | + ctx.errorCallback.onError(ErrorType.Syntax, node.getSourceRange(), "mapping pattern keys may only match literals and attribute lookups"); |
| 3771 | + } |
| 3772 | + assert constantValue != null; |
| 3773 | + Object pythonValue = PythonUtils.pythonObjectFromConstantValue(constantValue); |
| 3774 | + for (Object o : seen) { |
| 3775 | + // need python like equal - e.g. 1 equals True |
| 3776 | + if (PyObjectRichCompareBool.executeEqUncached(o, pythonValue)) { |
| 3777 | + ctx.errorCallback.onError(ErrorType.Syntax, node.getSourceRange(), "mapping pattern checks duplicate key (%s)", pythonValue); |
| 3778 | + } |
| 3779 | + } |
| 3780 | + seen.add(pythonValue); |
| 3781 | + createConstant(constantValue); |
| 3782 | + } |
| 3783 | + } |
| 3784 | + b.endCollectToObjectArray(); |
| 3785 | + b.endStoreLocal(); |
| 3786 | + |
| 3787 | + // save match result AND values |
| 3788 | + BytecodeLocal key_match_and_values = b.createLocal(); |
| 3789 | + b.beginStoreLocal(key_match_and_values); |
| 3790 | + b.beginMatchKeys(); |
| 3791 | + b.emitLoadLocal(pc.subject); |
| 3792 | + b.emitLoadLocal(keys_local); |
| 3793 | + b.endMatchKeys(); |
| 3794 | + b.endStoreLocal(); |
| 3795 | + |
| 3796 | + BytecodeLocal temp = b.createLocal(); |
| 3797 | + b.beginStoreLocal(temp); |
| 3798 | + |
| 3799 | + b.beginPrimitiveBoolAnd(); // AND keys and values |
| 3800 | + |
| 3801 | + // emit if everything matched |
| 3802 | + b.beginArrayIndex(0); |
| 3803 | + b.emitLoadLocal(key_match_and_values); |
| 3804 | + b.endArrayIndex(); |
| 3805 | + |
| 3806 | + b.beginBlock(); |
| 3807 | + |
| 3808 | + // unpack values from pc.subject |
| 3809 | + BytecodeLocal values_unpacked = b.createLocal(); |
| 3810 | + b.beginStoreLocal(values_unpacked); |
| 3811 | + b.beginUnpackSequence(pat_len); |
| 3812 | + b.beginArrayIndex(1); |
| 3813 | + b.emitLoadLocal(key_match_and_values); |
| 3814 | + b.endArrayIndex(); |
| 3815 | + b.endUnpackSequence(); |
| 3816 | + b.endStoreLocal(); |
| 3817 | + |
| 3818 | + b.beginPrimitiveBoolAnd(); // AND for sub-pats |
| 3819 | + |
| 3820 | + for (int i = 0; i < pat_len; i++) { |
| 3821 | + b.beginBlock(); |
| 3822 | + |
| 3823 | + b.beginStoreLocal(pc.subject); |
| 3824 | + b.beginArrayIndex(i); |
| 3825 | + b.emitLoadLocal(values_unpacked); |
| 3826 | + b.endArrayIndex(); |
| 3827 | + b.endStoreLocal(); |
| 3828 | + |
| 3829 | + visitSubpattern(patterns[i], pc); |
| 3830 | + |
| 3831 | + b.endBlock(); |
| 3832 | + } |
| 3833 | + |
| 3834 | + b.emitLoadConstant(true); |
| 3835 | + |
| 3836 | + b.endPrimitiveBoolAnd(); // AND for sub-pats |
| 3837 | + |
| 3838 | + b.endBlock(); |
| 3839 | + |
| 3840 | + b.endPrimitiveBoolAnd(); // AND keys and values |
| 3841 | + |
| 3842 | + b.endStoreLocal(); // temp |
| 3843 | + |
| 3844 | + // restore saved pc.subject |
| 3845 | + b.beginStoreLocal(pc.subject); |
| 3846 | + b.emitLoadLocal(pc_save); |
| 3847 | + b.endStoreLocal(); |
| 3848 | + |
| 3849 | + b.emitLoadLocal(temp); |
| 3850 | + |
| 3851 | + b.endBlock(); |
| 3852 | + |
| 3853 | + b.endPrimitiveBoolAnd(); // AND for sanity checks (length matching, etc.) |
3706 | 3854 | }
|
3707 | 3855 |
|
3708 | 3856 | private void checkAlternativePatternDifferentNames(Set<String> control, Map<String, BytecodeLocal> bindVariables) {
|
|
0 commit comments