diff --git a/Doc/library/ast.rst b/Doc/library/ast.rst index 319b2c81505f48..37089fd22b9774 100644 --- a/Doc/library/ast.rst +++ b/Doc/library/ast.rst @@ -363,6 +363,11 @@ Literals function call). This has the same meaning as ``FormattedValue.value``. * ``str`` is a constant containing the text of the interpolation expression. + + If ``str`` is set to ``None``, then ``value`` is used to generate code + when calling :func:`ast.unparse`. This no longer guarantees that the + generated code is identical to the original and is intended for code + generation. * ``conversion`` is an integer: * -1: no conversion diff --git a/Lib/_ast_unparse.py b/Lib/_ast_unparse.py index 16cf56f62cc1e5..1c8741b5a55483 100644 --- a/Lib/_ast_unparse.py +++ b/Lib/_ast_unparse.py @@ -658,9 +658,9 @@ def _unparse_interpolation_value(self, inner): unparser.set_precedence(_Precedence.TEST.next(), inner) return unparser.visit(inner) - def _write_interpolation(self, node, is_interpolation=False): + def _write_interpolation(self, node, use_str_attr=False): with self.delimit("{", "}"): - if is_interpolation: + if use_str_attr: expr = node.str else: expr = self._unparse_interpolation_value(node.value) @@ -678,7 +678,8 @@ def visit_FormattedValue(self, node): self._write_interpolation(node) def visit_Interpolation(self, node): - self._write_interpolation(node, is_interpolation=True) + # If `str` is set to `None`, use the `value` to generate the source code. + self._write_interpolation(node, use_str_attr=node.str is not None) def visit_Name(self, node): self.write(node.id) diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py index 0d6b05bc660b76..35e4652a87b423 100644 --- a/Lib/test/test_unparse.py +++ b/Lib/test/test_unparse.py @@ -206,6 +206,97 @@ def test_tstrings(self): self.check_ast_roundtrip("t'foo'") self.check_ast_roundtrip("t'foo {bar}'") self.check_ast_roundtrip("t'foo {bar!s:.2f}'") + self.check_ast_roundtrip("t'{a + b}'") + self.check_ast_roundtrip("t'{a + b:x}'") + self.check_ast_roundtrip("t'{a + b!s}'") + self.check_ast_roundtrip("t'{ {a}}'") + self.check_ast_roundtrip("t'{ {a}=}'") + self.check_ast_roundtrip("t'{{a}}'") + self.check_ast_roundtrip("t''") + self.check_ast_roundtrip('t""') + self.check_ast_roundtrip("t'{(lambda x: x)}'") + self.check_ast_roundtrip("t'{t'{x}'}'") + + def test_tstring_with_nonsensical_str_field(self): + # `value` suggests that the original code is `t'{test1}`, but `str` suggests otherwise + self.assertEqual( + ast.unparse( + ast.TemplateStr( + values=[ + ast.Interpolation( + value=ast.Name(id="test1", ctx=ast.Load()), str="test2", conversion=-1 + ) + ] + ) + ), + "t'{test2}'", + ) + + def test_tstring_with_none_str_field(self): + self.assertEqual( + ast.unparse( + ast.TemplateStr( + [ast.Interpolation(value=ast.Name(id="test1"), str=None, conversion=-1)] + ) + ), + "t'{test1}'", + ) + self.assertEqual( + ast.unparse( + ast.TemplateStr( + [ + ast.Interpolation( + value=ast.Lambda( + args=ast.arguments(args=[ast.arg(arg="x")]), + body=ast.Name(id="x"), + ), + str=None, + conversion=-1, + ) + ] + ) + ), + "t'{(lambda x: x)}'", + ) + self.assertEqual( + ast.unparse( + ast.TemplateStr( + values=[ + ast.Interpolation( + value=ast.TemplateStr( + # `str` field kept here + [ast.Interpolation(value=ast.Name(id="x"), str="y", conversion=-1)] + ), + str=None, + conversion=-1, + ) + ] + ) + ), + '''t"{t'{y}'}"''', + ) + self.assertEqual( + ast.unparse( + ast.TemplateStr( + values=[ + ast.Interpolation( + value=ast.TemplateStr( + [ast.Interpolation(value=ast.Name(id="x"), str=None, conversion=-1)] + ), + str=None, + conversion=-1, + ) + ] + ) + ), + '''t"{t'{x}'}"''', + ) + self.assertEqual( + ast.unparse(ast.TemplateStr( + [ast.Interpolation(value=ast.Constant(value="foo"), str=None, conversion=114)] + )), + '''t"{'foo'!r}"''', + ) def test_strings(self): self.check_ast_roundtrip("u'foo'") @@ -813,15 +904,6 @@ def test_type_params(self): self.check_ast_roundtrip("def f[T: int = int, **P = int, *Ts = *int]():\n pass") self.check_ast_roundtrip("class C[T: int = int, **P = int, *Ts = *int]():\n pass") - def test_tstr(self): - self.check_ast_roundtrip("t'{a + b}'") - self.check_ast_roundtrip("t'{a + b:x}'") - self.check_ast_roundtrip("t'{a + b!s}'") - self.check_ast_roundtrip("t'{ {a}}'") - self.check_ast_roundtrip("t'{ {a}=}'") - self.check_ast_roundtrip("t'{{a}}'") - self.check_ast_roundtrip("t''") - class ManualASTCreationTestCase(unittest.TestCase): """Test that AST nodes created without a type_params field unparse correctly."""