Skip to content

Commit ce0a32a

Browse files
authored
[flake8-comprehensions] Handle template strings for comprehension fixes (astral-sh#18710)
Essentially this PR ensures that when we do fixes like this: ```diff - t"{set(f(x) for x in foo)}" + t"{ {f(x) for x in foo} }" ``` we are correctly adding whitespace around the braces. This logic is already in place for f-strings and just needed to be generalized to interpolated strings.
1 parent 10a1d9f commit ce0a32a

14 files changed

+919
-23
lines changed

crates/ruff_linter/resources/test/fixtures/flake8_comprehensions/C401.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,19 @@ def f(x):
3636
# some more
3737
)
3838

39+
# t-strings
40+
print(t"Hello {set(f(a) for a in 'abc')} World")
41+
print(t"Hello { set(f(a) for a in 'abc') } World")
42+
small_nums = t"{set(a if a < 6 else 0 for a in range(3))}"
43+
print(t"Hello {set(a for a in range(3))} World")
44+
print(t"{set(a for a in 'abc') - set(a for a in 'ab')}")
45+
print(t"{ set(a for a in 'abc') - set(a for a in 'ab') }")
46+
47+
3948
# Not built-in set.
4049
def set(*args, **kwargs):
4150
return None
4251

4352
set(2 * x for x in range(3))
4453
set(x for x in range(3))
54+

crates/ruff_linter/resources/test/fixtures/flake8_comprehensions/C403.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,16 @@ def f(x):
3434
))))
3535

3636
# Test trailing comma case
37-
s = set([x for x in range(3)],)
37+
s = set([x for x in range(3)],)
38+
39+
s = t"{set([x for x in 'ab'])}"
40+
s = t'{set([x for x in "ab"])}'
41+
42+
def f(x):
43+
return x
44+
45+
s = t"{set([f(x) for x in 'ab'])}"
46+
47+
s = t"{ set([x for x in 'ab']) | set([x for x in 'ab']) }"
48+
s = t"{set([x for x in 'ab']) | set([x for x in 'ab'])}"
49+

crates/ruff_linter/resources/test/fixtures/flake8_comprehensions/C405.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,12 @@
2424
f"{ set(['a', 'b']) - set(['a']) }"
2525
f"a {set(['a', 'b']) - set(['a'])} b"
2626
f"a { set(['a', 'b']) - set(['a']) } b"
27+
28+
t"{set([1,2,3])}"
29+
t"{set(['a', 'b'])}"
30+
t'{set(["a", "b"])}'
31+
32+
t"{set(['a', 'b']) - set(['a'])}"
33+
t"{ set(['a', 'b']) - set(['a']) }"
34+
t"a {set(['a', 'b']) - set(['a'])} b"
35+
t"a { set(['a', 'b']) - set(['a']) } b"

crates/ruff_linter/resources/test/fixtures/flake8_comprehensions/C408.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,13 @@ def list():
2727

2828
tuple( # comment
2929
)
30+
31+
t"{dict(x='y')}"
32+
t'{dict(x="y")}'
33+
t"{dict()}"
34+
t"a {dict()} b"
35+
36+
t"{dict(x='y') | dict(y='z')}"
37+
t"{ dict(x='y') | dict(y='z') }"
38+
t"a {dict(x='y') | dict(y='z')} b"
39+
t"a { dict(x='y') | dict(y='z') } b"

crates/ruff_linter/resources/test/fixtures/flake8_comprehensions/C417.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,8 @@ def func(arg1: int, arg2: int = 4):
7070
list(map(lambda x, y: x, [(1, 2), (3, 4)]))
7171
list(map(lambda: 1, "xyz"))
7272
list(map(lambda x, y: x, [(1, 2), (3, 4)]))
73+
74+
# When inside t-string, then the fix should be surrounded by whitespace
75+
_ = t"{set(map(lambda x: x % 2 == 0, nums))}"
76+
_ = t"{dict(map(lambda v: (v, v**2), nums))}"
77+

crates/ruff_linter/src/checkers/ast/mod.rs

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ use ruff_python_ast::str::Quote;
3737
use ruff_python_ast::visitor::{Visitor, walk_except_handler, walk_pattern};
3838
use ruff_python_ast::{
3939
self as ast, AnyParameterRef, ArgOrKeyword, Comprehension, ElifElseClause, ExceptHandler, Expr,
40-
ExprContext, InterpolatedStringElement, Keyword, MatchCase, ModModule, Parameter, Parameters,
41-
Pattern, PythonVersion, Stmt, Suite, UnaryOp,
40+
ExprContext, ExprFString, ExprTString, InterpolatedStringElement, Keyword, MatchCase,
41+
ModModule, Parameter, Parameters, Pattern, PythonVersion, Stmt, Suite, UnaryOp,
4242
};
4343
use ruff_python_ast::{PySourceType, helpers, str, visitor};
4444
use ruff_python_codegen::{Generator, Stylist};
@@ -323,7 +323,8 @@ impl<'a> Checker<'a> {
323323
/// Return the preferred quote for a generated `StringLiteral` node, given where we are in the
324324
/// AST.
325325
fn preferred_quote(&self) -> Quote {
326-
self.f_string_quote_style().unwrap_or(self.stylist.quote())
326+
self.interpolated_string_quote_style()
327+
.unwrap_or(self.stylist.quote())
327328
}
328329

329330
/// Return the default string flags a generated `StringLiteral` node should use, given where we
@@ -345,21 +346,27 @@ impl<'a> Checker<'a> {
345346
ast::FStringFlags::empty().with_quote_style(self.preferred_quote())
346347
}
347348

348-
/// Returns the appropriate quoting for f-string by reversing the one used outside of
349-
/// the f-string.
349+
/// Returns the appropriate quoting for interpolated strings by reversing the one used outside of
350+
/// the interpolated string.
350351
///
351-
/// If the current expression in the context is not an f-string, returns ``None``.
352-
pub(crate) fn f_string_quote_style(&self) -> Option<Quote> {
353-
if !self.semantic.in_f_string() {
352+
/// If the current expression in the context is not an interpolated string, returns ``None``.
353+
pub(crate) fn interpolated_string_quote_style(&self) -> Option<Quote> {
354+
if !self.semantic.in_interpolated_string() {
354355
return None;
355356
}
356357

357-
// Find the quote character used to start the containing f-string.
358-
let ast::ExprFString { value, .. } = self
359-
.semantic
358+
// Find the quote character used to start the containing interpolated string.
359+
self.semantic
360360
.current_expressions()
361-
.find_map(|expr| expr.as_f_string_expr())?;
362-
Some(value.iter().next()?.quote_style().opposite())
361+
.find_map(|expr| match expr {
362+
Expr::FString(ExprFString { value, .. }) => {
363+
Some(value.iter().next()?.quote_style().opposite())
364+
}
365+
Expr::TString(ExprTString { value, .. }) => {
366+
Some(value.iter().next()?.quote_style().opposite())
367+
}
368+
_ => None,
369+
})
363370
}
364371

365372
/// Returns the [`SourceRow`] for the given offset.

crates/ruff_linter/src/rules/flake8_comprehensions/fixes.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,9 @@ pub(crate) fn fix_unnecessary_collection_call(
236236
// below.
237237
let mut arena: Vec<String> = vec![];
238238

239-
let quote = checker.f_string_quote_style().unwrap_or(stylist.quote());
239+
let quote = checker
240+
.interpolated_string_quote_style()
241+
.unwrap_or(stylist.quote());
240242

241243
// Quote each argument.
242244
for arg in &call.args {
@@ -317,7 +319,7 @@ pub(crate) fn pad_expression(
317319
locator: &Locator,
318320
semantic: &SemanticModel,
319321
) -> String {
320-
if !semantic.in_f_string() {
322+
if !semantic.in_interpolated_string() {
321323
return content;
322324
}
323325

@@ -349,7 +351,7 @@ pub(crate) fn pad_start(
349351
locator: &Locator,
350352
semantic: &SemanticModel,
351353
) -> String {
352-
if !semantic.in_f_string() {
354+
if !semantic.in_interpolated_string() {
353355
return content.into();
354356
}
355357

@@ -370,7 +372,7 @@ pub(crate) fn pad_end(
370372
locator: &Locator,
371373
semantic: &SemanticModel,
372374
) -> String {
373-
if !semantic.in_f_string() {
375+
if !semantic.in_interpolated_string() {
374376
return content.into();
375377
}
376378

@@ -798,10 +800,10 @@ pub(crate) fn fix_unnecessary_map(
798800

799801
let mut content = tree.codegen_stylist(stylist);
800802

801-
// If the expression is embedded in an f-string, surround it with spaces to avoid
803+
// If the expression is embedded in an interpolated string, surround it with spaces to avoid
802804
// syntax errors.
803805
if matches!(object_type, ObjectType::Set | ObjectType::Dict) {
804-
if parent.is_some_and(Expr::is_f_string_expr) {
806+
if parent.is_some_and(|expr| expr.is_f_string_expr() || expr.is_t_string_expr()) {
805807
content = format!(" {content} ");
806808
}
807809
}

crates/ruff_linter/src/rules/flake8_comprehensions/snapshots/ruff_linter__rules__flake8_comprehensions__tests__C401_C401.py.snap

Lines changed: 164 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ C401.py:32:1: C401 [*] Unnecessary generator (rewrite as a set comprehension)
346346
37 | | )
347347
| |_^ C401
348348
38 |
349-
39 | # Not built-in set.
349+
39 | # t-strings
350350
|
351351
= help: Rewrite as a set comprehension
352352

@@ -364,5 +364,166 @@ C401.py:32:1: C401 [*] Unnecessary generator (rewrite as a set comprehension)
364364
37 |-)
365365
35 |+ }
366366
38 36 |
367-
39 37 | # Not built-in set.
368-
40 38 | def set(*args, **kwargs):
367+
39 37 | # t-strings
368+
40 38 | print(t"Hello {set(f(a) for a in 'abc')} World")
369+
370+
C401.py:40:16: C401 [*] Unnecessary generator (rewrite as a set comprehension)
371+
|
372+
39 | # t-strings
373+
40 | print(t"Hello {set(f(a) for a in 'abc')} World")
374+
| ^^^^^^^^^^^^^^^^^^^^^^^^ C401
375+
41 | print(t"Hello { set(f(a) for a in 'abc') } World")
376+
42 | small_nums = t"{set(a if a < 6 else 0 for a in range(3))}"
377+
|
378+
= help: Rewrite as a set comprehension
379+
380+
Unsafe fix
381+
37 37 | )
382+
38 38 |
383+
39 39 | # t-strings
384+
40 |-print(t"Hello {set(f(a) for a in 'abc')} World")
385+
40 |+print(t"Hello { {f(a) for a in 'abc'} } World")
386+
41 41 | print(t"Hello { set(f(a) for a in 'abc') } World")
387+
42 42 | small_nums = t"{set(a if a < 6 else 0 for a in range(3))}"
388+
43 43 | print(t"Hello {set(a for a in range(3))} World")
389+
390+
C401.py:41:17: C401 [*] Unnecessary generator (rewrite as a set comprehension)
391+
|
392+
39 | # t-strings
393+
40 | print(t"Hello {set(f(a) for a in 'abc')} World")
394+
41 | print(t"Hello { set(f(a) for a in 'abc') } World")
395+
| ^^^^^^^^^^^^^^^^^^^^^^^^ C401
396+
42 | small_nums = t"{set(a if a < 6 else 0 for a in range(3))}"
397+
43 | print(t"Hello {set(a for a in range(3))} World")
398+
|
399+
= help: Rewrite as a set comprehension
400+
401+
Unsafe fix
402+
38 38 |
403+
39 39 | # t-strings
404+
40 40 | print(t"Hello {set(f(a) for a in 'abc')} World")
405+
41 |-print(t"Hello { set(f(a) for a in 'abc') } World")
406+
41 |+print(t"Hello { {f(a) for a in 'abc'} } World")
407+
42 42 | small_nums = t"{set(a if a < 6 else 0 for a in range(3))}"
408+
43 43 | print(t"Hello {set(a for a in range(3))} World")
409+
44 44 | print(t"{set(a for a in 'abc') - set(a for a in 'ab')}")
410+
411+
C401.py:42:17: C401 [*] Unnecessary generator (rewrite as a set comprehension)
412+
|
413+
40 | print(t"Hello {set(f(a) for a in 'abc')} World")
414+
41 | print(t"Hello { set(f(a) for a in 'abc') } World")
415+
42 | small_nums = t"{set(a if a < 6 else 0 for a in range(3))}"
416+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ C401
417+
43 | print(t"Hello {set(a for a in range(3))} World")
418+
44 | print(t"{set(a for a in 'abc') - set(a for a in 'ab')}")
419+
|
420+
= help: Rewrite as a set comprehension
421+
422+
Unsafe fix
423+
39 39 | # t-strings
424+
40 40 | print(t"Hello {set(f(a) for a in 'abc')} World")
425+
41 41 | print(t"Hello { set(f(a) for a in 'abc') } World")
426+
42 |-small_nums = t"{set(a if a < 6 else 0 for a in range(3))}"
427+
42 |+small_nums = t"{ {a if a < 6 else 0 for a in range(3)} }"
428+
43 43 | print(t"Hello {set(a for a in range(3))} World")
429+
44 44 | print(t"{set(a for a in 'abc') - set(a for a in 'ab')}")
430+
45 45 | print(t"{ set(a for a in 'abc') - set(a for a in 'ab') }")
431+
432+
C401.py:43:16: C401 [*] Unnecessary generator (rewrite using `set()`)
433+
|
434+
41 | print(t"Hello { set(f(a) for a in 'abc') } World")
435+
42 | small_nums = t"{set(a if a < 6 else 0 for a in range(3))}"
436+
43 | print(t"Hello {set(a for a in range(3))} World")
437+
| ^^^^^^^^^^^^^^^^^^^^^^^^ C401
438+
44 | print(t"{set(a for a in 'abc') - set(a for a in 'ab')}")
439+
45 | print(t"{ set(a for a in 'abc') - set(a for a in 'ab') }")
440+
|
441+
= help: Rewrite using `set()`
442+
443+
Unsafe fix
444+
40 40 | print(t"Hello {set(f(a) for a in 'abc')} World")
445+
41 41 | print(t"Hello { set(f(a) for a in 'abc') } World")
446+
42 42 | small_nums = t"{set(a if a < 6 else 0 for a in range(3))}"
447+
43 |-print(t"Hello {set(a for a in range(3))} World")
448+
43 |+print(t"Hello {set(range(3))} World")
449+
44 44 | print(t"{set(a for a in 'abc') - set(a for a in 'ab')}")
450+
45 45 | print(t"{ set(a for a in 'abc') - set(a for a in 'ab') }")
451+
46 46 |
452+
453+
C401.py:44:10: C401 [*] Unnecessary generator (rewrite using `set()`)
454+
|
455+
42 | small_nums = t"{set(a if a < 6 else 0 for a in range(3))}"
456+
43 | print(t"Hello {set(a for a in range(3))} World")
457+
44 | print(t"{set(a for a in 'abc') - set(a for a in 'ab')}")
458+
| ^^^^^^^^^^^^^^^^^^^^^ C401
459+
45 | print(t"{ set(a for a in 'abc') - set(a for a in 'ab') }")
460+
|
461+
= help: Rewrite using `set()`
462+
463+
Unsafe fix
464+
41 41 | print(t"Hello { set(f(a) for a in 'abc') } World")
465+
42 42 | small_nums = t"{set(a if a < 6 else 0 for a in range(3))}"
466+
43 43 | print(t"Hello {set(a for a in range(3))} World")
467+
44 |-print(t"{set(a for a in 'abc') - set(a for a in 'ab')}")
468+
44 |+print(t"{set('abc') - set(a for a in 'ab')}")
469+
45 45 | print(t"{ set(a for a in 'abc') - set(a for a in 'ab') }")
470+
46 46 |
471+
47 47 |
472+
473+
C401.py:44:34: C401 [*] Unnecessary generator (rewrite using `set()`)
474+
|
475+
42 | small_nums = t"{set(a if a < 6 else 0 for a in range(3))}"
476+
43 | print(t"Hello {set(a for a in range(3))} World")
477+
44 | print(t"{set(a for a in 'abc') - set(a for a in 'ab')}")
478+
| ^^^^^^^^^^^^^^^^^^^^ C401
479+
45 | print(t"{ set(a for a in 'abc') - set(a for a in 'ab') }")
480+
|
481+
= help: Rewrite using `set()`
482+
483+
Unsafe fix
484+
41 41 | print(t"Hello { set(f(a) for a in 'abc') } World")
485+
42 42 | small_nums = t"{set(a if a < 6 else 0 for a in range(3))}"
486+
43 43 | print(t"Hello {set(a for a in range(3))} World")
487+
44 |-print(t"{set(a for a in 'abc') - set(a for a in 'ab')}")
488+
44 |+print(t"{set(a for a in 'abc') - set('ab')}")
489+
45 45 | print(t"{ set(a for a in 'abc') - set(a for a in 'ab') }")
490+
46 46 |
491+
47 47 |
492+
493+
C401.py:45:11: C401 [*] Unnecessary generator (rewrite using `set()`)
494+
|
495+
43 | print(t"Hello {set(a for a in range(3))} World")
496+
44 | print(t"{set(a for a in 'abc') - set(a for a in 'ab')}")
497+
45 | print(t"{ set(a for a in 'abc') - set(a for a in 'ab') }")
498+
| ^^^^^^^^^^^^^^^^^^^^^ C401
499+
|
500+
= help: Rewrite using `set()`
501+
502+
Unsafe fix
503+
42 42 | small_nums = t"{set(a if a < 6 else 0 for a in range(3))}"
504+
43 43 | print(t"Hello {set(a for a in range(3))} World")
505+
44 44 | print(t"{set(a for a in 'abc') - set(a for a in 'ab')}")
506+
45 |-print(t"{ set(a for a in 'abc') - set(a for a in 'ab') }")
507+
45 |+print(t"{ set('abc') - set(a for a in 'ab') }")
508+
46 46 |
509+
47 47 |
510+
48 48 | # Not built-in set.
511+
512+
C401.py:45:35: C401 [*] Unnecessary generator (rewrite using `set()`)
513+
|
514+
43 | print(t"Hello {set(a for a in range(3))} World")
515+
44 | print(t"{set(a for a in 'abc') - set(a for a in 'ab')}")
516+
45 | print(t"{ set(a for a in 'abc') - set(a for a in 'ab') }")
517+
| ^^^^^^^^^^^^^^^^^^^^ C401
518+
|
519+
= help: Rewrite using `set()`
520+
521+
Unsafe fix
522+
42 42 | small_nums = t"{set(a if a < 6 else 0 for a in range(3))}"
523+
43 43 | print(t"Hello {set(a for a in range(3))} World")
524+
44 44 | print(t"{set(a for a in 'abc') - set(a for a in 'ab')}")
525+
45 |-print(t"{ set(a for a in 'abc') - set(a for a in 'ab') }")
526+
45 |+print(t"{ set(a for a in 'abc') - set('ab') }")
527+
46 46 |
528+
47 47 |
529+
48 48 | # Not built-in set.

0 commit comments

Comments
 (0)