Skip to content
This repository was archived by the owner on Jan 13, 2026. It is now read-only.

Commit a4c9587

Browse files
authored
Add support for autoformatting comprehensions (#99)
* Add support for autoformatting set,dict,list comprehensions * Convert match case to ensure python 3.9 compatibility * Refactor generator comprehension and add extra tests * Add extra comments * Format
1 parent a2a19ff commit a4c9587

File tree

2 files changed

+159
-1
lines changed

2 files changed

+159
-1
lines changed

rewrite/rewrite/python/format/spaces_visitor.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
MethodDeclaration, Empty, ArrayAccess, Space, If, Block, ClassDeclaration, VariableDeclarations, JRightPadded, \
77
Import
88
from rewrite.python import PythonVisitor, SpacesStyle, Binary, ChainedAssignment, Slice, CollectionLiteral, \
9-
ForLoop, DictLiteral, KeyValue, TypeHint, MultiImport, ExpressionTypeTree
9+
ForLoop, DictLiteral, KeyValue, TypeHint, MultiImport, ExpressionTypeTree, ComprehensionExpression
1010
from rewrite.visitor import P
1111

1212

@@ -408,6 +408,43 @@ def visit_expression_type_tree(self, expression_type_tree: ExpressionTypeTree, p
408408
ett = space_before(ett, False)
409409
return ett
410410

411+
def visit_comprehension_expression(self, comprehension_expression: ComprehensionExpression, p: P) -> J:
412+
ce = cast(ComprehensionExpression, super().visit_comprehension_expression(comprehension_expression, p))
413+
414+
# Handle space before result this will depend on the style setting for the comprehension type.
415+
if ce.kind == ComprehensionExpression.Kind.LIST:
416+
ce = ce.with_result(space_before(ce.result, self._style.within.brackets))
417+
ce = ce.with_suffix(update_space(ce.suffix, self._style.within.brackets))
418+
elif ce.kind == ComprehensionExpression.Kind.GENERATOR:
419+
ce = ce.with_result(space_before(ce.result, False))
420+
ce = ce.with_suffix(update_space(ce.suffix, False))
421+
elif ce.kind in (ComprehensionExpression.Kind.SET, ComprehensionExpression.Kind.DICT):
422+
ce = ce.with_result(space_before(ce.result, self._style.within.braces))
423+
ce = ce.with_suffix(update_space(ce.suffix, self._style.within.braces))
424+
425+
return ce
426+
427+
def visit_comprehension_condition(self, condition: ComprehensionExpression.Condition, p: P) -> J:
428+
cond = cast(ComprehensionExpression.Condition, super().visit_comprehension_condition(condition, p))
429+
# Set single space before and after comprehension 'if' keyword.
430+
cond = space_before(cond, True)
431+
cond = cond.with_expression(space_before(cond.expression, True))
432+
return cond
433+
434+
def visit_comprehension_clause(self, clause: ComprehensionExpression.Clause, p: P) -> J:
435+
cc = cast(ComprehensionExpression.Clause, super().visit_comprehension_clause(clause, p))
436+
437+
# Ensure single space before 'for' keyword
438+
cc = space_before(cc, True)
439+
440+
# Single before 'in' keyword e.g. ..i in... <-> ...i in...
441+
cc = cc.padding.with_iterated_list(space_before_left_padded(cc.padding.iterated_list, True))
442+
# Single space before 'iterator' variable (or after for keyword) e.g. ...for i <-> ...for i
443+
cc = cc.with_iterator_variable(space_before(cc.iterator_variable, True))
444+
# Ensure single space after 'in' keyword e.g. ...in range(10) <-> ...in range(10)
445+
cc = cc.padding.with_iterated_list(space_before_left_padded_element(cc.padding.iterated_list, True))
446+
return cc
447+
411448
def _remap_trailing_comma_space(self, tc: j.TrailingComma) -> j.TrailingComma:
412449
return tc.with_suffix(update_space(tc.suffix, self._style.other.after_comma))
413450

@@ -438,6 +475,9 @@ def space_before_container(container: j.JContainer, add_space: bool) -> j.JConta
438475
return container
439476

440477

478+
def space_before_left_padded_element(container: j.JLeftPadded, add_space: bool) -> j.JLeftPadded:
479+
return container.with_element(space_before(container.element, add_space))
480+
441481
def space_before_right_padded_element(container: j.JRightPadded, add_space: bool) -> j.JRightPadded:
442482
return container.with_element(space_before(container.element, add_space))
443483

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import pytest
2+
3+
from rewrite.python import IntelliJ, SpacesVisitor
4+
from rewrite.test import rewrite_run, python, RecipeSpec, from_visitor
5+
6+
7+
@pytest.mark.parametrize("within_brackets", [False, True])
8+
def test_spaces_with_list_comprehension(within_brackets):
9+
style = IntelliJ.spaces()
10+
style = style.with_within(
11+
style.within.with_brackets(within_brackets)
12+
)
13+
_s = " " if within_brackets else ""
14+
rewrite_run(
15+
# language=python
16+
python(
17+
"""\
18+
a = [ i*2 for i in range(0, 10)]
19+
a = [ i*2 for i in [1, 2, 3 ]]
20+
""",
21+
f"""\
22+
a = [i * 2 for i in range(0, 10)]
23+
a = [i * 2 for i in [1, 2, 3]]
24+
""".replace("[", "[" + _s).replace("]", _s + "]")
25+
),
26+
spec=RecipeSpec()
27+
.with_recipe(from_visitor(SpacesVisitor(style)))
28+
)
29+
30+
31+
def test_spaces_with_generator_comprehension():
32+
style = IntelliJ.spaces()
33+
34+
rewrite_run(
35+
# language=python
36+
python(
37+
"""\
38+
a = ( i*2 for i in range(0, 10))
39+
""",
40+
f"""\
41+
a = (i * 2 for i in range(0, 10))
42+
"""
43+
),
44+
spec=RecipeSpec()
45+
.with_recipe(from_visitor(SpacesVisitor(style)))
46+
)
47+
48+
49+
@pytest.mark.parametrize("within_brackets", [False, True])
50+
def test_spaces_with_list_comprehension_with_condition(within_brackets):
51+
style = IntelliJ.spaces()
52+
style = style.with_within(
53+
style.within.with_brackets(within_brackets)
54+
)
55+
_s = " " if within_brackets else ""
56+
rewrite_run(
57+
# language=python
58+
python(
59+
"""\
60+
a = [ i* 2 for i in range(0, 10) if i % 2 == 0 ]
61+
""",
62+
"""\
63+
a = [i * 2 for i in range(0, 10) if i % 2 == 0]
64+
""".replace("[", "[" + _s).replace("]", _s + "]")
65+
),
66+
spec=RecipeSpec()
67+
.with_recipe(from_visitor(SpacesVisitor(style)))
68+
)
69+
70+
71+
@pytest.mark.parametrize("within_braces", [False, True])
72+
def test_spaces_with_set_comprehension(within_braces):
73+
style = IntelliJ.spaces()
74+
style = style.with_within(
75+
style.within.with_braces(within_braces)
76+
)
77+
_s = " " if within_braces else ""
78+
rewrite_run(
79+
# language=python
80+
python(
81+
"""\
82+
a = {i*2 for i in range(0, 10)}
83+
a = {i for i in {1, 2, 3 }}
84+
""",
85+
"""\
86+
a = {i * 2 for i in range(0, 10)}
87+
a = {i for i in {1, 2, 3}}
88+
""".replace("{", "{" + _s).replace("}", _s + "}")
89+
),
90+
spec=RecipeSpec()
91+
.with_recipe(from_visitor(SpacesVisitor(style)))
92+
)
93+
94+
95+
@pytest.mark.parametrize("within_braces", [False, True])
96+
def test_spaces_with_dict_comprehension(within_braces):
97+
style = IntelliJ.spaces()
98+
style = style.with_within(
99+
style.within.with_braces(within_braces)
100+
)
101+
_s = " " if within_braces else ""
102+
rewrite_run(
103+
# language=python
104+
python(
105+
"""\
106+
a = {i: i*2 for i in range(0, 10)}
107+
a = {i: i for i in [1, 2, 3]}
108+
a = {k: v*2 for k,v in { "a": 2, "b": 4}.items( ) }
109+
""",
110+
"""\
111+
a = {i: i * 2 for i in range(0, 10)}
112+
a = {i: i for i in [1, 2, 3]}
113+
a = {k: v * 2 for k, v in {"a": 2, "b": 4}.items()}
114+
""".replace("{", "{" + _s).replace("}", _s + "}")
115+
),
116+
spec=RecipeSpec()
117+
.with_recipe(from_visitor(SpacesVisitor(style)))
118+
)

0 commit comments

Comments
 (0)