Skip to content

Commit aa3141d

Browse files
committed
Fix: pattern matching with dynamic attribute
1 parent 0cd9cbe commit aa3141d

File tree

3 files changed

+39
-24
lines changed

3 files changed

+39
-24
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_patmat.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
1+
# Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved.
22
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
33
#
44
# The Universal Permissive License (UPL), Version 1.0
@@ -106,3 +106,20 @@ def star_match(x):
106106

107107
assert star_match(d) == {33:33}
108108

109+
def test_mutable_dict_keys():
110+
class MyObj:
111+
pass
112+
113+
def forward(**kwargs):
114+
return kwargs
115+
116+
def test(name):
117+
to_match = {'attr1': 1, 'attr2': 2, 'attr3': 3}
118+
x = MyObj()
119+
x.myattr = name
120+
match to_match:
121+
case {x.myattr: dyn_match, **data}:
122+
return forward(dyn_match=dyn_match, **data)
123+
124+
assert test('attr1') == {'dyn_match': 1, 'attr2': 2, 'attr3': 3}
125+
assert test('attr2') == {'dyn_match': 2, 'attr1': 1, 'attr3': 3}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/bytecode/CopyDictWithoutKeysNode.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* The Universal Permissive License (UPL), Version 1.0
@@ -47,31 +47,33 @@
4747
import com.oracle.graal.python.runtime.object.PythonObjectFactory;
4848
import com.oracle.truffle.api.CompilerAsserts;
4949
import com.oracle.truffle.api.dsl.Cached;
50+
import com.oracle.truffle.api.dsl.GenerateInline;
5051
import com.oracle.truffle.api.dsl.NeverDefault;
5152
import com.oracle.truffle.api.dsl.Specialization;
5253
import com.oracle.truffle.api.frame.Frame;
5354
import com.oracle.truffle.api.frame.VirtualFrame;
5455
import com.oracle.truffle.api.nodes.ExplodeLoop;
5556

57+
@GenerateInline(false) // used in BCI root node
5658
public abstract class CopyDictWithoutKeysNode extends PNodeWithContext {
5759
public abstract PDict execute(Frame frame, Object subject, Object[] keys);
5860

59-
@Specialization(guards = "keysArg.length <= 32")
60-
static PDict copy(VirtualFrame frame, Object subject, @NeverDefault @SuppressWarnings("unused") Object[] keysArg,
61-
@Cached(value = "keysArg", dimensions = 1) Object[] keys,
61+
@Specialization(guards = {"keys.length == keysLength", "keysLength <= 32"}, limit = "1")
62+
static PDict copy(VirtualFrame frame, Object subject, @NeverDefault @SuppressWarnings("unused") Object[] keys,
63+
@Cached("keys.length") int keysLength,
6264
@Cached PythonObjectFactory factory,
6365
@Cached DictNodes.UpdateNode updateNode,
6466
@Cached DictBuiltins.DelItemNode delItemNode) {
6567
PDict rest = factory.createDict();
6668
updateNode.execute(frame, rest, subject);
67-
deleteKeys(frame, keys, delItemNode, rest);
69+
deleteKeys(frame, keys, keysLength, delItemNode, rest);
6870
return rest;
6971
}
7072

7173
@ExplodeLoop
72-
private static void deleteKeys(VirtualFrame frame, Object[] keys, DictBuiltins.DelItemNode delItemNode, PDict rest) {
73-
CompilerAsserts.partialEvaluationConstant(keys);
74-
for (int i = 0; i < keys.length; i++) {
74+
private static void deleteKeys(VirtualFrame frame, Object[] keys, int keysLen, DictBuiltins.DelItemNode delItemNode, PDict rest) {
75+
CompilerAsserts.partialEvaluationConstant(keysLen);
76+
for (int i = 0; i < keysLen; i++) {
7577
delItemNode.execute(frame, rest, keys[i]);
7678
}
7779
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/bytecode/MatchKeysNode.java

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* The Universal Permissive License (UPL), Version 1.0
@@ -62,27 +62,24 @@
6262
public abstract class MatchKeysNode extends PNodeWithContext {
6363
public abstract Object execute(Frame frame, Object map, Object[] keys);
6464

65-
@Specialization(guards = {"keys.length > 0", "keys.length <= 32"})
66-
static Object match(VirtualFrame frame, Object map, @NeverDefault Object[] keysArg,
67-
@Cached(value = "keysArg", dimensions = 1) Object[] keys,
65+
@Specialization(guards = {"keys.length == keysLen", "keysLen > 0", "keysLen <= 32"}, limit = "1")
66+
static Object matchCached(VirtualFrame frame, Object map, @NeverDefault Object[] keys,
67+
@Cached("keys.length") int keysLen,
6868
@Cached PyObjectRichCompareBool.EqNode compareNode,
6969
@Cached PyObjectCallMethodObjArgs callMethod,
7070
@Cached PythonObjectFactory factory,
7171
@Cached PRaiseNode raise) {
72-
if (keys.length == 0) {
73-
return factory.createTuple(PythonUtils.EMPTY_OBJECT_ARRAY);
74-
}
75-
Object[] values = getValues(frame, map, keys, compareNode, callMethod, raise);
72+
Object[] values = getValues(frame, map, keys, keysLen, compareNode, callMethod, raise);
7673
return values != null ? factory.createTuple(values) : PNone.NONE;
7774
}
7875

7976
@ExplodeLoop
80-
private static Object[] getValues(VirtualFrame frame, Object map, Object[] keys, PyObjectRichCompareBool.EqNode compareNode, PyObjectCallMethodObjArgs callMethod, PRaiseNode raise) {
81-
CompilerAsserts.partialEvaluationConstant(keys);
82-
Object[] values = new Object[keys.length];
77+
private static Object[] getValues(VirtualFrame frame, Object map, Object[] keys, int keysLen, PyObjectRichCompareBool.EqNode compareNode, PyObjectCallMethodObjArgs callMethod, PRaiseNode raise) {
78+
CompilerAsserts.partialEvaluationConstant(keysLen);
79+
Object[] values = new Object[keysLen];
8380
Object dummy = new Object();
84-
Object[] seen = new Object[keys.length];
85-
for (int i = 0; i < keys.length; i++) {
81+
Object[] seen = new Object[keysLen];
82+
for (int i = 0; i < values.length; i++) {
8683
Object key = keys[i];
8784
checkSeen(frame, raise, seen, key, compareNode);
8885
seen[i] = key;
@@ -104,7 +101,7 @@ private static void checkSeen(VirtualFrame frame, PRaiseNode raise, Object[] see
104101
}
105102
}
106103

107-
@Specialization(guards = "keys.length > 32")
104+
@Specialization(guards = "keys.length > 0", replaces = "matchCached")
108105
static Object match(VirtualFrame frame, Object map, Object[] keys,
109106
@Cached PyObjectRichCompareBool.EqNode compareNode,
110107
@Cached PyObjectCallMethodObjArgs callMethod,
@@ -118,7 +115,6 @@ static Object match(VirtualFrame frame, Object map, Object[] keys,
118115
}
119116

120117
private static Object[] getValuesLongArray(VirtualFrame frame, Object map, Object[] keys, PyObjectRichCompareBool.EqNode compareNode, PyObjectCallMethodObjArgs callMethod, PRaiseNode raise) {
121-
CompilerAsserts.partialEvaluationConstant(keys);
122118
Object[] values = new Object[keys.length];
123119
Object dummy = new Object();
124120
Object[] seen = new Object[keys.length];

0 commit comments

Comments
 (0)