@@ -206,6 +206,97 @@ def test_tstrings(self):
206206 self .check_ast_roundtrip ("t'foo'" )
207207 self .check_ast_roundtrip ("t'foo {bar}'" )
208208 self .check_ast_roundtrip ("t'foo {bar!s:.2f}'" )
209+ self .check_ast_roundtrip ("t'{a + b}'" )
210+ self .check_ast_roundtrip ("t'{a + b:x}'" )
211+ self .check_ast_roundtrip ("t'{a + b!s}'" )
212+ self .check_ast_roundtrip ("t'{ {a}}'" )
213+ self .check_ast_roundtrip ("t'{ {a}=}'" )
214+ self .check_ast_roundtrip ("t'{{a}}'" )
215+ self .check_ast_roundtrip ("t''" )
216+ self .check_ast_roundtrip ('t""' )
217+ self .check_ast_roundtrip ("t'{(lambda x: x)}'" )
218+ self .check_ast_roundtrip ("t'{t'{x}'}'" )
219+
220+ def test_tstring_with_nonsensical_str_field (self ):
221+ # `value` suggests that the original code is `t'{test1}`, but `str` suggests otherwise
222+ self .assertEqual (
223+ ast .unparse (
224+ ast .TemplateStr (
225+ values = [
226+ ast .Interpolation (
227+ value = ast .Name (id = "test1" , ctx = ast .Load ()), str = "test2" , conversion = - 1
228+ )
229+ ]
230+ )
231+ ),
232+ "t'{test2}'" ,
233+ )
234+
235+ def test_tstring_with_none_str_field (self ):
236+ self .assertEqual (
237+ ast .unparse (
238+ ast .TemplateStr (
239+ [ast .Interpolation (value = ast .Name (id = "test1" ), str = None , conversion = - 1 )]
240+ )
241+ ),
242+ "t'{test1}'" ,
243+ )
244+ self .assertEqual (
245+ ast .unparse (
246+ ast .TemplateStr (
247+ [
248+ ast .Interpolation (
249+ value = ast .Lambda (
250+ args = ast .arguments (args = [ast .arg (arg = "x" )]),
251+ body = ast .Name (id = "x" ),
252+ ),
253+ str = None ,
254+ conversion = - 1 ,
255+ )
256+ ]
257+ )
258+ ),
259+ "t'{(lambda x: x)}'" ,
260+ )
261+ self .assertEqual (
262+ ast .unparse (
263+ ast .TemplateStr (
264+ values = [
265+ ast .Interpolation (
266+ value = ast .TemplateStr (
267+ # `str` field kept here
268+ [ast .Interpolation (value = ast .Name (id = "x" ), str = "y" , conversion = - 1 )]
269+ ),
270+ str = None ,
271+ conversion = - 1 ,
272+ )
273+ ]
274+ )
275+ ),
276+ '''t"{t'{y}'}"''' ,
277+ )
278+ self .assertEqual (
279+ ast .unparse (
280+ ast .TemplateStr (
281+ values = [
282+ ast .Interpolation (
283+ value = ast .TemplateStr (
284+ [ast .Interpolation (value = ast .Name (id = "x" ), str = None , conversion = - 1 )]
285+ ),
286+ str = None ,
287+ conversion = - 1 ,
288+ )
289+ ]
290+ )
291+ ),
292+ '''t"{t'{x}'}"''' ,
293+ )
294+ self .assertEqual (
295+ ast .unparse (ast .TemplateStr (
296+ [ast .Interpolation (value = ast .Constant (value = "foo" ), str = None , conversion = 114 )]
297+ )),
298+ '''t"{'foo'!r}"''' ,
299+ )
209300
210301 def test_strings (self ):
211302 self .check_ast_roundtrip ("u'foo'" )
@@ -813,15 +904,6 @@ def test_type_params(self):
813904 self .check_ast_roundtrip ("def f[T: int = int, **P = int, *Ts = *int]():\n pass" )
814905 self .check_ast_roundtrip ("class C[T: int = int, **P = int, *Ts = *int]():\n pass" )
815906
816- def test_tstr (self ):
817- self .check_ast_roundtrip ("t'{a + b}'" )
818- self .check_ast_roundtrip ("t'{a + b:x}'" )
819- self .check_ast_roundtrip ("t'{a + b!s}'" )
820- self .check_ast_roundtrip ("t'{ {a}}'" )
821- self .check_ast_roundtrip ("t'{ {a}=}'" )
822- self .check_ast_roundtrip ("t'{{a}}'" )
823- self .check_ast_roundtrip ("t''" )
824-
825907
826908class ManualASTCreationTestCase (unittest .TestCase ):
827909 """Test that AST nodes created without a type_params field unparse correctly."""
0 commit comments