Skip to content

Commit 4bab9b4

Browse files
Python: Fix MatchSequence parser incorrectly consuming child's delimiter (#6713)
* Fix MatchSequence parser incorrectly consuming child's delimiter

  When parsing `case [c], _:`, the outer implicit-tuple MatchSequence was greedily consuming the `[` that belongs to the inner child MatchSequence. This caused the pattern to be parsed as a single bracketed list `[c, _]` instead of an implicit tuple containing `[c]` and `_`.

  The fix adds `__is_own_sequence_delimiter`, which checks whether a `[` or `(` delimiter belongs to this MatchSequence or to its first child by combining AST inspection (is the first child also a MatchSequence?) with token peeking (are there consecutive delimiters?).

  Also includes RPC server improvements: `relativeTo` path support for parse inputs, bare-string PathInput handling, and better error logging.

* Remove newly introduced logger.info calls

---------

Co-authored-by: Tim te Beek <tim@moderne.io>
1 parent 485764e commit 4bab9b4

File tree

3 files changed

+164
-19
lines changed

3 files changed

+164
-19
lines changed

rewrite-python/rewrite/src/rewrite/python/_parser_visitor.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,10 +1208,12 @@ def visit_MatchValue(self, node):
12081208
def visit_MatchSequence(self, node):
12091209
prefix = self.__whitespace()
12101210
end_delim = None
1211-
if self.__skip('['):
1211+
if self.__at_token('[') and self.__is_own_sequence_delimiter(node, '['):
1212+
self.__skip('[')
12121213
kind = py.MatchCase.Pattern.Kind.SEQUENCE_LIST
12131214
end_delim = ']'
1214-
elif self.__skip('('):
1215+
elif self.__at_token('(') and self.__is_own_sequence_delimiter(node, '('):
1216+
self.__skip('(')
12151217
kind = py.MatchCase.Pattern.Kind.SEQUENCE_TUPLE
12161218
end_delim = ')'
12171219
else:
@@ -3321,3 +3323,26 @@ def __at_token(self, s: str) -> bool:
33213323
if self._token_idx >= len(self._tokens):
33223324
return False
33233325
return self._tokens[self._token_idx].string == s
3326+
3327+
def __is_own_sequence_delimiter(self, node, delim: str) -> bool:
3328+
"""Check if the delimiter at the current token belongs to this MatchSequence.
3329+
3330+
When the current token is '[' (or '('), it could belong to this
3331+
sequence or to its first child (e.g., ``[c], _`` vs ``[c, _]``).
3332+
3333+
If the first child pattern is itself a MatchSequence, the delimiter
3334+
might belong to the child. We disambiguate by peeking at the next
3335+
token: if it is also a delimiter (``[`` or ``(``), the current one
3336+
opens this sequence (e.g., ``[[a], b]``); otherwise the current
3337+
delimiter belongs to the child (e.g., ``[c], _``).
3338+
"""
3339+
import ast as stdlib_ast
3340+
if node.patterns and isinstance(node.patterns[0], stdlib_ast.MatchSequence):
3341+
# The first child is also a sequence — check whether there are
3342+
# two consecutive delimiters, meaning the outer one is ours.
3343+
next_idx = self._token_idx + 1
3344+
if next_idx < len(self._tokens):
3345+
next_tok = self._tokens[next_idx].string
3346+
return next_tok in ('[', '(')
3347+
return False
3348+
return True

rewrite-python/rewrite/src/rewrite/rpc/server.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -209,20 +209,28 @@ def generate_id() -> str:
209209
return str(uuid4())
210210

211211

212-
def parse_python_file(path: str) -> dict:
212+
def parse_python_file(path: str, relative_to: Optional[str] = None) -> dict:
213213
"""Parse a Python file and return its LST."""
214214
with open(path, 'r', encoding='utf-8') as f:
215215
source = f.read()
216-
return parse_python_source(source, path)
216+
return parse_python_source(source, path, relative_to)
217217

218218

219-
def parse_python_source(source: str, path: str = "<unknown>") -> dict:
219+
def parse_python_source(source: str, path: str = "<unknown>", relative_to: Optional[str] = None) -> dict:
220220
"""Parse Python source code and return its LST.
221221
222222
The parser used depends on the REWRITE_PYTHON_VERSION environment variable:
223223
- "2" or "2.7": Use parso-based Py2ParserVisitor for Python 2 code
224224
- "3" (default): Use ast-based ParserVisitor for Python 3 code
225225
"""
226+
# Compute the source_path that will be stored on the LST
227+
source_path = Path(path)
228+
if relative_to is not None:
229+
try:
230+
source_path = source_path.relative_to(relative_to)
231+
except ValueError:
232+
pass # path is not under relative_to, keep absolute
233+
226234
try:
227235
from rewrite import Markers
228236

@@ -250,7 +258,7 @@ def parse_python_source(source: str, path: str = "<unknown>") -> dict:
250258
# Convert to OpenRewrite LST
251259
cu = ParserVisitor(source, path).visit(tree)
252260

253-
cu = cu.replace(source_path=Path(path))
261+
cu = cu.replace(source_path=source_path)
254262
cu = cu.replace(markers=Markers.EMPTY)
255263

256264
# Store and return
@@ -263,14 +271,14 @@ def parse_python_source(source: str, path: str = "<unknown>") -> dict:
263271
except ImportError as e:
264272
logger.error(f"Failed to import parser: {e}")
265273
traceback.print_exc()
266-
return _create_parse_error(path, str(e), source)
274+
return _create_parse_error(str(source_path), str(e), source)
267275
except SyntaxError as e:
268276
logger.error(f"Syntax error parsing {path}: {e}")
269-
return _create_parse_error(path, str(e), source)
277+
return _create_parse_error(str(source_path), str(e), source)
270278
except Exception as e:
271279
logger.error(f"Error parsing {path}: {e}")
272280
traceback.print_exc()
273-
return _create_parse_error(path, str(e), source)
281+
return _create_parse_error(str(source_path), str(e), source)
274282

275283

276284
def _create_parse_error(path: str, message: str, source: str = '') -> dict:
@@ -313,18 +321,23 @@ def _create_parse_error(path: str, message: str, source: str = '') -> dict:
313321
def handle_parse(params: dict) -> List[str]:
314322
"""Handle a Parse RPC request."""
315323
inputs = params.get('inputs', [])
324+
relative_to = params.get('relativeTo')
316325
results = []
317326

318-
for input_item in inputs:
319-
if 'path' in input_item:
320-
# File input
321-
result = parse_python_file(input_item['path'])
327+
for i, input_item in enumerate(inputs):
328+
if isinstance(input_item, str):
329+
# PathInput serialized via @JsonValue as a bare path string
330+
result = parse_python_file(input_item, relative_to)
331+
elif 'path' in input_item:
332+
# File input as dict
333+
result = parse_python_file(input_item['path'], relative_to)
322334
elif 'text' in input_item or 'source' in input_item:
323335
# String input - Java sends 'text' and 'sourcePath'
324336
source = input_item.get('text') or input_item.get('source')
325337
path = input_item.get('sourcePath') or input_item.get('relativePath', '<unknown>')
326-
result = parse_python_source(source, path)
338+
result = parse_python_source(source, path, relative_to)
327339
else:
340+
logger.warning(f" [{i}] unknown input type: {type(input_item)}")
328341
continue
329342
results.append(result['id'])
330343

@@ -388,17 +401,15 @@ def handle_get_object(params: dict) -> List[dict]:
388401

389402
q = RpcSendQueue(source_file_type)
390403
result = q.generate(obj, before)
391-
logger.debug(f"GetObject result: {len(result)} items")
392-
for i, item in enumerate(result[:10]): # Log first 10 items
393-
logger.debug(f" [{i}] {item}")
394404

395405
# Update remote_objects to track that Java now has this version
396406
remote_objects[obj_id] = obj
397407

398408
return result
399409

400-
except Exception as e:
401-
logger.error(f"Error serializing object: {e}")
410+
except BaseException as e:
411+
source_path = getattr(obj, 'source_path', None)
412+
logger.error(f"Error serializing object {obj_id} (type={type(obj).__name__}, path={source_path}): {e}")
402413
import traceback as tb
403414
tb.print_exc()
404415
return [{'state': 'END_OF_OBJECT'}]

rewrite-python/rewrite/tests/python/all/tree/match_test.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,115 @@ def f(x):
117117
))
118118

119119

120+
def test_match_with_star_wildcard_and_capture():
121+
# language=python - star wildcard followed by a capture variable
122+
RecipeSpec().rewrite_run(python(
123+
"""\
124+
match x:
125+
case [*_, stmt]:
126+
pass
127+
"""
128+
))
129+
130+
131+
def test_match_with_star_capture_and_variable():
132+
# language=python - star capture followed by a variable
133+
RecipeSpec().rewrite_run(python(
134+
"""\
135+
match x:
136+
case [*prev, stmt]:
137+
pass
138+
"""
139+
))
140+
141+
142+
def test_match_with_star_wildcard_expression_not_none():
143+
# Verify that the Star node for *_ has a non-None expression
144+
import ast as stdlib_ast
145+
from rewrite.python._parser_visitor import ParserVisitor
146+
147+
code = """\
148+
match x:
149+
case [*_, stmt]:
150+
pass
151+
"""
152+
tree = stdlib_ast.parse(code)
153+
cu = ParserVisitor(code, 'test.py').visit(tree)
154+
155+
# Walk the LST to find Star nodes
156+
stars = []
157+
_collect_stars(cu, stars)
158+
assert len(stars) > 0, "Expected at least one Star node"
159+
for star in stars:
160+
assert star.expression is not None, "Star expression should not be None for *_ pattern"
161+
162+
163+
def test_match_with_nested_star_wildcard_expression_not_none():
164+
# Verify that Star nodes in nested match-class patterns have non-None expression
165+
# This is the pattern from refurb that triggers the PythonValidator error
166+
import ast as stdlib_ast
167+
from rewrite.python._parser_visitor import ParserVisitor
168+
169+
code = """\
170+
match node:
171+
case IfStmt(else_body=Block(body=[*_, stmt])) | WithStmt(body=Block(body=[*_, stmt])):
172+
pass
173+
case ForStmt(body=Block(body=[*prev, stmt])) | WhileStmt(body=Block(body=[*prev, stmt])):
174+
pass
175+
"""
176+
tree = stdlib_ast.parse(code)
177+
cu = ParserVisitor(code, 'test.py').visit(tree)
178+
179+
stars = []
180+
_collect_stars(cu, stars)
181+
assert len(stars) > 0, f"Expected Star nodes, found none"
182+
for star in stars:
183+
assert star.expression is not None, f"Star expression should not be None"
184+
185+
186+
def _collect_stars(node, result, visited=None):
    """Append every ``Star`` reachable from ``node`` to ``result``.

    Traversal follows dataclass fields only: a field value is descended into
    when it is itself a dataclass instance, or when it is a list whose items
    are dataclass instances or objects exposing an ``element`` attribute
    (in which case ``item.element`` is visited).

    Args:
        node: Root object to search from.
        result: List that found ``Star`` nodes are appended to, in place.
        visited: Set of ``id()`` values already seen, used to avoid visiting
            the same object twice; created on the first call when omitted.
    """
    from rewrite.python.tree import Star

    seen = set() if visited is None else visited
    key = id(node)
    if key in seen:
        return
    seen.add(key)

    if isinstance(node, Star):
        result.append(node)

    # Only dataclass instances have fields worth walking.
    if not hasattr(node, '__dataclass_fields__'):
        return

    scalar_types = (str, int, float, bool, bytes)
    for name in node.__dataclass_fields__:
        value = getattr(node, name, None)
        if value is None or isinstance(value, scalar_types):
            continue
        if hasattr(value, '__dataclass_fields__'):
            _collect_stars(value, result, seen)
            continue
        if not isinstance(value, list):
            continue
        for entry in value:
            if hasattr(entry, '__dataclass_fields__'):
                _collect_stars(entry, result, seen)
            elif hasattr(entry, 'element'):
                _collect_stars(entry.element, result, seen)
214+
215+
216+
def test_match_tuple_with_sequence_pattern():
217+
# language=python - implicit tuple match with sequence pattern [c] as first element
218+
RecipeSpec().rewrite_run(python(
219+
"""\
220+
match (y, z):
221+
case _, b:
222+
pass
223+
case [c], _:
224+
pass
225+
"""
226+
))
227+
228+
120229
def test_match_with_or_pattern_in_tuple():
121230
# language=python - OR pattern as first element of implicit tuple
122231
RecipeSpec().rewrite_run(python(

0 commit comments

Comments
 (0)