Skip to content

Commit ae9b154

Browse files
author
Michal Medvecky
committed
class pattern match
1 parent 2607d6a commit ae9b154

File tree

2 files changed

+201
-10
lines changed

2 files changed

+201
-10
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/compiler/bytecode_dsl/RootNodeCompiler.java

Lines changed: 189 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3640,7 +3640,7 @@ private void visitPattern(PatternTy pattern, PatternContext pc) {
36403640
if (pattern instanceof PatternTy.MatchAs matchAs) {
36413641
doVisitPattern(matchAs, pc);
36423642
} else if (pattern instanceof PatternTy.MatchClass matchClass) {
3643-
doVisitPattern(matchClass);
3643+
doVisitPattern(matchClass, pc);
36443644
} else if (pattern instanceof PatternTy.MatchMapping matchMapping) {
36453645
doVisitPattern(matchMapping, pc);
36463646
} else if (pattern instanceof PatternTy.MatchOr matchOr) {
@@ -3698,8 +3698,192 @@ private void emitPatternNotImplemented(String kind) {
36983698
b.endBlock();
36993699
}
37003700

3701-
private void doVisitPattern(PatternTy.MatchClass node) {
3702-
emitPatternNotImplemented("class");
3701+
/**
3702+
* Saves subject of the pattern context into BytecodeLocal variable, to be restored eventually.
3703+
* @param pc Pattern context, which subject needs to be saved.
3704+
* @return Subject saved in local variable.
3705+
*/
3706+
private BytecodeLocal patternContextSubjectSave(PatternContext pc) {
3707+
BytecodeLocal pcSave = b.createLocal();
3708+
b.beginStoreLocal(pcSave);
3709+
b.emitLoadLocal(pc.subject);
3710+
b.endStoreLocal();
3711+
return pcSave;
3712+
}
3713+
3714+
/**
3715+
* Loads pattern context subject back into pattern context.
3716+
* @param pcSave Variable to restore pattern context subject from.
3717+
* @param pc Pattern context into which the subject should be restored.
3718+
*/
3719+
private void patternContextSubjectLoad(BytecodeLocal pcSave, PatternContext pc) {
3720+
b.beginStoreLocal(pc.subject);
3721+
b.emitLoadLocal(pcSave);
3722+
b.endStoreLocal();
3723+
}
3724+
3725+
/**
3726+
* Check if attribute and keyword attribute lengths match, or if there isn't too much patterns or attributes.
3727+
* Throws error on fail.
3728+
* @param patLen Patterns count
3729+
* @param attrsLen Attributes count
3730+
* @param kwdPatLen Keyword attributes count
3731+
* @param node MatchClass node for errors
3732+
*/
3733+
private void classMatchLengthChecks(int patLen, int attrsLen, int kwdPatLen, PatternTy.MatchClass node) {
3734+
if (attrsLen != kwdPatLen) {
3735+
ctx.errorCallback.onError(ErrorType.Syntax, node.getSourceRange(), "kwd_attrs (%d) / kwd_patterns (%d) length mismatch in class pattern", attrsLen, kwdPatLen);
3736+
}
3737+
if (Integer.MAX_VALUE < patLen + attrsLen - 1) {
3738+
String id = node.cls instanceof ExprTy.Name ? ((ExprTy.Name) node.cls).id : node.cls.toString();
3739+
ctx.errorCallback.onError(ErrorType.Syntax, node.getSourceRange(), "too many sub-patterns in class pattern %s", id);
3740+
}
3741+
3742+
}
3743+
3744+
/**
3745+
* Visits sub-patterns for class pattern matching. Regular, positional patterns are handled first, then the
3746+
* keyword patterns (e.g. the "class.attribute = [keyword] pattern"). Generates boolean value based on
3747+
* results of the subpatterns; values are evaluated using the AND operator.
3748+
* @param patterns Patterns to check as subpatterns.
3749+
* @param kwdPatterns Keyword patterns to check as subpatterns.
3750+
* @param attrsValueUnpacked Values to use as `pc.subject` in sub-pattern check.
3751+
* @param pc Pattern context (subject is saved then restored).
3752+
* @param patLen Number of patterns.
3753+
* @param attrsLen Number of attributes (also keyword patterns).
3754+
*/
3755+
private void classMatchVisitSubpatterns(PatternTy[] patterns, PatternTy[] kwdPatterns, BytecodeLocal attrsValueUnpacked, PatternContext pc, int patLen, int attrsLen) {
3756+
BytecodeLocal pcSave = patternContextSubjectSave(pc);
3757+
3758+
if (patLen + attrsLen == 0) {
3759+
b.emitLoadConstant(true);
3760+
} else {
3761+
BytecodeLocal temp = b.createLocal();
3762+
b.beginStoreLocal(temp);
3763+
b.beginPrimitiveBoolAnd();
3764+
for (int i = 0; i < patLen; i++) {
3765+
b.beginBlock();
3766+
b.beginStoreLocal(pc.subject);
3767+
b.beginArrayIndex(i);
3768+
b.emitLoadLocal(attrsValueUnpacked);
3769+
b.endArrayIndex();
3770+
b.endStoreLocal();
3771+
3772+
visitSubpattern(patterns[i], pc);
3773+
b.endBlock();
3774+
}
3775+
3776+
for (int i = 0, j = patLen; i < attrsLen; i++, j++) {
3777+
b.beginBlock();
3778+
b.beginStoreLocal(pc.subject);
3779+
b.beginArrayIndex(j);
3780+
b.emitLoadLocal(attrsValueUnpacked);
3781+
b.endArrayIndex();
3782+
b.endStoreLocal();
3783+
3784+
visitSubpattern(kwdPatterns[i], pc);
3785+
b.endBlock();
3786+
}
3787+
b.endPrimitiveBoolAnd();
3788+
b.endStoreLocal();
3789+
3790+
patternContextSubjectLoad(pcSave, pc);
3791+
3792+
b.emitLoadLocal(temp);
3793+
}
3794+
}
3795+
3796+
private void doVisitPattern(PatternTy.MatchClass node, PatternContext pc) {
3797+
/**
3798+
* Class pattern matching consists of subject and pattern. Pattern is split into:
3799+
* <ul>
3800+
* <li> patterns: These are positional and match the {@code __match_args__} arguments of the class, and are
3801+
* evaluated as sub-patterns with respective positional class attributes as subjects.
3802+
* <li> keyword attributes (kwdAttrs): These are non-positional, named class attributes that need to match
3803+
* the accompanying keyword patterns.
3804+
* <li> keyword patterns (kwdPatterns): Patterns that accompany keyword attributes, these are evaluated as
3805+
* sub-patterns with provided class attributes as subjects. Note that the number of keyword attributes
3806+
* and keyword patterns do need to match.
3807+
* </ul>
3808+
*
3809+
* Example:
3810+
* @formatter:off
3811+
* x = <some class>
3812+
* match x:
3813+
* case <class>(x, 42 as y, a = ("test1" | "test2") as z):
3814+
* ...
3815+
* @formatter:on
3816+
* Here, {@code x} and {@code 42 as y} are "patterns" (positional), {@code a} is "keyword attribute" and
3817+
* {@code ... as z} is its accompanying "keyword pattern".
3818+
*/
3819+
3820+
b.beginBlock();
3821+
3822+
PatternTy[] patterns = node.patterns;
3823+
String[] kwdAttrs = node.kwdAttrs;
3824+
PatternTy[] kwdPatterns = node.kwdPatterns;
3825+
int patLen = lengthOrZero(patterns);
3826+
int attrsLen = lengthOrZero(kwdAttrs);
3827+
int kwdPatLen = lengthOrZero(kwdPatterns);
3828+
3829+
classMatchLengthChecks(patLen, attrsLen, kwdPatLen, node);
3830+
if (attrsLen > 0) {
3831+
validateKwdAttrs(kwdAttrs, kwdPatterns);
3832+
}
3833+
3834+
//@formatter:off
3835+
// attributes needs to be converted into truffle strings
3836+
TruffleString[] tsAttrs = new TruffleString[attrsLen];
3837+
for (int i = 0; i < attrsLen; i++) {
3838+
tsAttrs[i] = toTruffleStringUncached(kwdAttrs[i]);
3839+
}
3840+
3841+
b.beginPrimitiveBoolAnd();
3842+
BytecodeLocal attrsValue = b.createLocal();
3843+
// match class that's in the subject
3844+
b.beginMatchClass(attrsValue);
3845+
b.emitLoadLocal(pc.subject);
3846+
node.cls.accept(this); // get class type
3847+
b.emitLoadConstant(patLen);
3848+
b.emitLoadConstant(tsAttrs);
3849+
b.endMatchClass();
3850+
3851+
b.beginBlock();
3852+
// attributes from match class needs to be unpacked first
3853+
BytecodeLocal attrsValueUnpacked = b.createLocal();
3854+
b.beginStoreLocal(attrsValueUnpacked);
3855+
b.beginUnpackSequence(patLen + attrsLen);
3856+
b.emitLoadLocal(attrsValue);
3857+
b.endUnpackSequence();
3858+
b.endStoreLocal();
3859+
3860+
classMatchVisitSubpatterns(patterns, kwdPatterns, attrsValueUnpacked, pc, patLen, attrsLen);
3861+
b.endBlock();
3862+
b.endPrimitiveBoolAnd();
3863+
3864+
b.endBlock();
3865+
//@formatter:on
3866+
}
3867+
3868+
/**
3869+
* Checks if keyword argument names aren't the same or if their name isn't forbidden. Raises error at fail.
3870+
* @param attrs Attributes to check.
3871+
* @param patterns Patterns for error source range.
3872+
*/
3873+
private void validateKwdAttrs(String[] attrs, PatternTy[] patterns) {
3874+
// Any errors will point to the pattern rather than the arg name as the
3875+
// parser is only supplying identifiers rather than Name or keyword nodes
3876+
int attrsLen = lengthOrZero(attrs);
3877+
for (int i = 0; i < attrsLen; i++) {
3878+
String attr = attrs[i];
3879+
checkForbiddenName(attr, NameOperation.BeginWrite, patterns[i].getSourceRange());
3880+
for (int j = i + 1; j < attrsLen; j++) {
3881+
String other = attrs[j];
3882+
if (attr.equals(other)) {
3883+
ctx.errorCallback.onError(ErrorType.Syntax, patterns[j].getSourceRange(), "attribute name repeated in class pattern: `%s`", attr);
3884+
}
3885+
}
3886+
}
37033887
}
37043888

37053889
private static int lengthOrZero(Object[] p) {
@@ -3787,10 +3971,7 @@ private void mappingVisitSubpatterns(PatternTy[] patterns, BytecodeLocal values,
37873971
b.endStoreLocal();
37883972

37893973
// backup pc.subject, it will get replaced for sub-patterns
3790-
BytecodeLocal pcSave = b.createLocal();
3791-
b.beginStoreLocal(pcSave);
3792-
b.emitLoadLocal(pc.subject);
3793-
b.endStoreLocal();
3974+
BytecodeLocal pcSave = patternContextSubjectSave(pc);
37943975

37953976
BytecodeLocal temp = b.createLocal();
37963977
b.beginStoreLocal(temp);
@@ -3817,9 +3998,7 @@ private void mappingVisitSubpatterns(PatternTy[] patterns, BytecodeLocal values,
38173998
b.endPrimitiveBoolAnd();
38183999
b.endStoreLocal();
38194000

3820-
b.beginStoreLocal(pc.subject);
3821-
b.emitLoadLocal(pcSave);
3822-
b.endStoreLocal();
4001+
patternContextSubjectLoad(pcSave, pc);
38234002

38244003
b.emitLoadLocal(temp);
38254004
b.endBlock();

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/bytecode_dsl/PBytecodeDSLRootNode.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@
169169
import com.oracle.graal.python.nodes.bytecode.ImportFromNode;
170170
import com.oracle.graal.python.nodes.bytecode.ImportNode;
171171
import com.oracle.graal.python.nodes.bytecode.ImportStarNode;
172+
import com.oracle.graal.python.nodes.bytecode.MatchClassNode;
172173
import com.oracle.graal.python.nodes.bytecode.MatchKeysNode;
173174
import com.oracle.graal.python.nodes.bytecode.PrintExprNode;
174175
import com.oracle.graal.python.nodes.bytecode.RaiseNode;
@@ -1142,6 +1143,17 @@ public static PDict perform(VirtualFrame frame, Object map, Object[] keys, @Cach
11421143
}
11431144
}
11441145

1146+
@Operation
1147+
@ConstantOperand(type = LocalAccessor.class)
1148+
public static final class MatchClass {
1149+
@Specialization
1150+
public static Object perform(VirtualFrame frame, LocalAccessor attributes, Object subject, Object type, int nargs, TruffleString[] kwArgs, @Bind BytecodeNode bytecodeNode, @Cached MatchClassNode node) {
1151+
Object attrs = node.execute(frame, subject, type, nargs, kwArgs);
1152+
attributes.setObject(bytecodeNode, frame, attrs);
1153+
return attrs != null;
1154+
}
1155+
}
1156+
11451157
@Operation
11461158
@ConstantOperand(type = TruffleString.class, name = "name")
11471159
@ConstantOperand(type = TruffleString.class, name = "qualifiedName")

0 commit comments

Comments
 (0)