Skip to content

Commit 65737bd

Browse files
author
Michal Medvecky
committed
[GR-63204][GR-63581] Multiple OR patterns support, same variable name usage in "case", mapping patma support, class patma support.
PullRequest: graalpython/3764
2 parents f6c3764 + 6a15c45 commit 65737bd

File tree

6 files changed

+1045
-72
lines changed

6 files changed

+1045
-72
lines changed

graalpython/com.oracle.graal.python.test/src/tests/cpyext/test_abstract.py

Lines changed: 50 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1567,29 +1567,28 @@ def _reference_seq_repeat(args):
15671567
case _:
15681568
return args[0] * args[1]
15691569

1570-
if not os.environ.get('BYTECODE_DSL_INTERPRETER'): # TODO: class pattern matching
1571-
test_PySequence_Repeat = CPyExtFunction(
1572-
_reference_seq_repeat,
1573-
lambda: (
1574-
((1,), 0),
1575-
((1,), 1),
1576-
((1,), 3),
1577-
([1], 0),
1578-
([1], 1),
1579-
([1], 3),
1580-
("hello", 0),
1581-
("hello", 1),
1582-
("hello", 3),
1583-
({}, 0),
1584-
(SeqWithMulAdd(), 42),
1585-
(NonSeqWithMulAdd(), 24),
1586-
(DictSubclassWithSequenceMethods(), 5),
1587-
),
1588-
resultspec="O",
1589-
argspec='On',
1590-
arguments=["PyObject* obj", "Py_ssize_t n"],
1591-
cmpfunc=unhandled_error_compare
1592-
)
1570+
test_PySequence_Repeat = CPyExtFunction(
1571+
_reference_seq_repeat,
1572+
lambda: (
1573+
((1,), 0),
1574+
((1,), 1),
1575+
((1,), 3),
1576+
([1], 0),
1577+
([1], 1),
1578+
([1], 3),
1579+
("hello", 0),
1580+
("hello", 1),
1581+
("hello", 3),
1582+
({}, 0),
1583+
(SeqWithMulAdd(), 42),
1584+
(NonSeqWithMulAdd(), 24),
1585+
(DictSubclassWithSequenceMethods(), 5),
1586+
),
1587+
resultspec="O",
1588+
argspec='On',
1589+
arguments=["PyObject* obj", "Py_ssize_t n"],
1590+
cmpfunc=unhandled_error_compare
1591+
)
15931592

15941593
test_PySequence_InPlaceRepeat = CPyExtFunction(
15951594
lambda args: args[0] * args[1],
@@ -1622,35 +1621,34 @@ def _reference_seq_concat(args):
16221621
case _:
16231622
return args[0] + args[1]
16241623

1625-
if not os.environ.get('BYTECODE_DSL_INTERPRETER'): # TODO: class pattern matching
1626-
test_PySequence_Concat = CPyExtFunction(
1627-
_reference_seq_concat,
1628-
lambda: (
1629-
((1,), tuple()),
1630-
((1,), list()),
1631-
((1,), (2,)),
1632-
((1,), [2,]),
1633-
([1], tuple()),
1634-
([1], list()),
1635-
([1], (2,)),
1636-
([1], [2,]),
1637-
("hello", "world"),
1638-
("hello", ""),
1639-
({}, []),
1640-
([], {}),
1641-
(SeqWithMulAdd(), 1),
1642-
(SeqWithMulAdd(), SeqWithMulAdd()),
1643-
(SeqWithMulAdd(), [1,2,3]),
1644-
(NonSeqWithMulAdd(), 2),
1645-
(NonSeqWithMulAdd(), [1,2,3]),
1646-
(DictSubclassWithSequenceMethods(), (1,2,3)),
1647-
((1,2,3), DictSubclassWithSequenceMethods()),
1648-
),
1649-
resultspec="O",
1650-
argspec='OO',
1651-
arguments=["PyObject* s", "PyObject* o"],
1652-
cmpfunc=unhandled_error_compare
1653-
)
1624+
test_PySequence_Concat = CPyExtFunction(
1625+
_reference_seq_concat,
1626+
lambda: (
1627+
((1,), tuple()),
1628+
((1,), list()),
1629+
((1,), (2,)),
1630+
((1,), [2,]),
1631+
([1], tuple()),
1632+
([1], list()),
1633+
([1], (2,)),
1634+
([1], [2,]),
1635+
("hello", "world"),
1636+
("hello", ""),
1637+
({}, []),
1638+
([], {}),
1639+
(SeqWithMulAdd(), 1),
1640+
(SeqWithMulAdd(), SeqWithMulAdd()),
1641+
(SeqWithMulAdd(), [1,2,3]),
1642+
(NonSeqWithMulAdd(), 2),
1643+
(NonSeqWithMulAdd(), [1,2,3]),
1644+
(DictSubclassWithSequenceMethods(), (1,2,3)),
1645+
((1,2,3), DictSubclassWithSequenceMethods()),
1646+
),
1647+
resultspec="O",
1648+
argspec='OO',
1649+
arguments=["PyObject* s", "PyObject* o"],
1650+
cmpfunc=unhandled_error_compare
1651+
)
16541652

16551653
test_PySequence_InPlaceConcat = CPyExtFunction(
16561654
lambda args: args[0] + list(args[1]) if isinstance(args[0], list) else args[0] + args[1],

graalpython/com.oracle.graal.python.test/src/tests/test_patmat.py

Lines changed: 143 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@
3838
# SOFTWARE.
3939
import os
4040
import sys, ast, unittest
41+
import inspect
4142

4243

43-
@unittest.skipIf(sys.version_info.minor < 10, "Requires Python 3.10+")
4444
def test_guard():
4545
def f(x, g):
4646
match x:
@@ -62,7 +62,6 @@ def f(x):
6262
assert f(1) == 42
6363
assert f(2) == 0
6464

65-
@unittest.skipIf(sys.version_info.minor < 10, "Requires Python 3.10+")
6665
def test_complex_as_binary_op():
6766
src = """
6867
def f(a):
@@ -87,8 +86,6 @@ def f(a):
8786
assert f(6+3j) == "match add"
8887
assert f(-2-3j) == "match sub"
8988

90-
@unittest.skipIf(sys.version_info.minor < 10, "Requires Python 3.10+")
91-
@unittest.skipIf(os.environ.get('BYTECODE_DSL_INTERPRETER'), "TODO: mapping pattern matching")
9289
def test_long_mapping():
9390
def f(x):
9491
match d:
@@ -107,7 +104,6 @@ def star_match(x):
107104

108105
assert star_match(d) == {33:33}
109106

110-
@unittest.skipIf(os.environ.get('BYTECODE_DSL_INTERPRETER'), "TODO: mapping pattern matching")
111107
def test_mutable_dict_keys():
112108
class MyObj:
113109
pass
@@ -125,3 +121,145 @@ def test(name):
125121

126122
assert test('attr1') == {'dyn_match': 1, 'attr2': 2, 'attr3': 3}
127123
assert test('attr2') == {'dyn_match': 2, 'attr1': 1, 'attr3': 3}
124+
125+
def test_multiple_or_pattern_basic():
126+
match 0:
127+
case 0 | 1 | 2 | 3 | 4 | 5 as x:
128+
assert x == 0
129+
130+
match 3:
131+
case ((0 | 1 | 2) as x) | ((3 | 4 | 5) as x):
132+
assert x == 3
133+
134+
def test_sequence_pattern():
135+
match (1, 2):
136+
case (3, 2):
137+
assert False
138+
139+
match (1, (2, 2)):
140+
case (3, (2, 2)):
141+
assert False
142+
143+
match (1, 2):
144+
case (3, q):
145+
assert False
146+
147+
148+
def test_multiple_or_pattern_advanced():
149+
match 4:
150+
case (0 as z) | (1 as z) | (2 as z) | (4 as z) | (77 as z):
151+
assert z == 4
152+
153+
match 42:
154+
case (0 as z) | (1 as z):
155+
assert z == 1
156+
case x:
157+
assert x == 42
158+
159+
match 2:
160+
case (0 as z) | (1 as z) | (2 as z):
161+
assert z == 2
162+
case _:
163+
assert False
164+
165+
match 1:
166+
case (0 as z) | (1 as z) | (2 as z):
167+
assert z == 1
168+
case _:
169+
assert False
170+
171+
match 0:
172+
case (0 as z) | (1 as z) | (2 as z):
173+
assert z == 0
174+
case _:
175+
assert False
176+
177+
match (1, 2):
178+
case (w, 2) | (2, w):
179+
assert w == 1
180+
181+
182+
def test_multiple_or_pattern_creates_locals():
183+
match (1, 2):
184+
case (a, 1) | (a, 2):
185+
assert a == 1
186+
assert a == 1
187+
188+
match [1, 2]:
189+
case [a1, 1] | [a1, 2]:
190+
assert a1 == 1
191+
assert a1 == 1
192+
193+
match (1, 2, 2, 3, 2):
194+
case (1, a, b, 4, c) | (1, a, b, 3, c) | (1, a, b, 2, c):
195+
assert a == 2
196+
assert b == 2
197+
assert c == 2
198+
assert a == 2
199+
assert b == 2
200+
assert c == 2
201+
202+
match (1, 3, 4, 9):
203+
case ((d, e, f, 7) | (d, e, f, 8) | (d, e, f, 6) | (d, e, f, 9)):
204+
assert d == 1
205+
assert e == 3
206+
assert f == 4
207+
assert d == 1
208+
assert e == 3
209+
assert f == 4
210+
211+
match (1,2,3,4,5,6,7):
212+
case (0,q,w,e,r,t,y) | (q,w,e,r,t,y,7):
213+
assert q == 1
214+
assert w == 2
215+
assert e == 3
216+
assert r == 4
217+
assert t == 5
218+
assert y == 6
219+
assert q == 1
220+
assert w == 2
221+
assert e == 3
222+
assert r == 4
223+
assert t == 5
224+
assert y == 6
225+
226+
227+
class TestErrors(unittest.TestCase):
228+
def assert_syntax_error(self, code: str):
229+
with self.assertRaises(SyntaxError):
230+
compile(inspect.cleandoc(code), "<test>", "exec")
231+
232+
def test_alternative_patterns_bind_different_names_0(self):
233+
self.assert_syntax_error("""
234+
match ...:
235+
case "a" | a:
236+
pass
237+
""")
238+
239+
def test_alternative_patterns_bind_different_names_1(self):
240+
self.assert_syntax_error("""
241+
match ...:
242+
case [a, [b] | [c] | [d]]:
243+
pass
244+
""")
245+
246+
def test_multiple_or_same_name(self):
247+
self.assert_syntax_error("""
248+
match 0:
249+
case x | x:
250+
pass
251+
""")
252+
253+
def test_multiple_or_wildcard(self):
254+
self.assert_syntax_error("""
255+
match 0:
256+
case * | 1:
257+
pass
258+
""")
259+
260+
def test_unbound_local_variable(self):
261+
with self.assertRaises(UnboundLocalError):
262+
match (1, 3):
263+
case (a, 1) | (a, 2):
264+
pass
265+
assert a == 1

0 commit comments

Comments
 (0)