Skip to content
63 changes: 39 additions & 24 deletions Lib/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,7 @@ def __init__(self):
self._type_ignores = {}
self._indent = 0
self._in_try_star = False
self._in_interactive = False

def interleave(self, inter, f, seq):
"""Call f on each item in seq, calling inter() in between."""
Expand Down Expand Up @@ -702,11 +703,20 @@ def maybe_newline(self):
if self._source:
self.write("\n")

def fill(self, text=""):
def maybe_semicolon(self):
"""Adds a "; " delimiter if it isn't the start of generated source"""
if self._source:
self.write("; ")

def fill(self, text="", allow_semi=True):
"""Indent a piece of text and append it, according to the current
indentation level"""
self.maybe_newline()
self.write(" " * self._indent + text)
indentation level, or only delineate with semicolon if applicable"""
if self._in_interactive and not self._indent and allow_semi:
self.maybe_semicolon()
self.write(text)
else:
self.maybe_newline()
self.write(" " * self._indent + text)

def write(self, *text):
"""Add new source parts"""
Expand Down Expand Up @@ -815,6 +825,11 @@ def visit_Module(self, node):
self._write_docstring_and_traverse_body(node)
self._type_ignores.clear()

def visit_Interactive(self, node):
self._in_interactive = True
self._write_docstring_and_traverse_body(node)
self._in_interactive = False

def visit_FunctionType(self, node):
with self.delimit("(", ")"):
self.interleave(
Expand Down Expand Up @@ -945,17 +960,17 @@ def visit_Raise(self, node):
self.traverse(node.cause)

def do_visit_try(self, node):
self.fill("try")
self.fill("try", allow_semi=False)
with self.block():
self.traverse(node.body)
for ex in node.handlers:
self.traverse(ex)
if node.orelse:
self.fill("else")
self.fill("else", allow_semi=False)
with self.block():
self.traverse(node.orelse)
if node.finalbody:
self.fill("finally")
self.fill("finally", allow_semi=False)
with self.block():
self.traverse(node.finalbody)

Expand All @@ -976,7 +991,7 @@ def visit_TryStar(self, node):
self._in_try_star = prev_in_try_star

def visit_ExceptHandler(self, node):
self.fill("except*" if self._in_try_star else "except")
self.fill("except*" if self._in_try_star else "except", allow_semi=False)
if node.type:
self.write(" ")
self.traverse(node.type)
Expand All @@ -989,9 +1004,9 @@ def visit_ExceptHandler(self, node):
def visit_ClassDef(self, node):
self.maybe_newline()
for deco in node.decorator_list:
self.fill("@")
self.fill("@", allow_semi=False)
self.traverse(deco)
self.fill("class " + node.name)
self.fill("class " + node.name, allow_semi=False)
if hasattr(node, "type_params"):
self._type_params_helper(node.type_params)
with self.delimit_if("(", ")", condition = node.bases or node.keywords):
Expand Down Expand Up @@ -1021,10 +1036,10 @@ def visit_AsyncFunctionDef(self, node):
def _function_helper(self, node, fill_suffix):
self.maybe_newline()
for deco in node.decorator_list:
self.fill("@")
self.fill("@", allow_semi=False)
self.traverse(deco)
def_str = fill_suffix + " " + node.name
self.fill(def_str)
self.fill(def_str, allow_semi=False)
if hasattr(node, "type_params"):
self._type_params_helper(node.type_params)
with self.delimit("(", ")"):
Expand Down Expand Up @@ -1075,54 +1090,54 @@ def visit_AsyncFor(self, node):
self._for_helper("async for ", node)

def _for_helper(self, fill, node):
self.fill(fill)
self.fill(fill, allow_semi=False)
self.set_precedence(_Precedence.TUPLE, node.target)
self.traverse(node.target)
self.write(" in ")
self.traverse(node.iter)
with self.block(extra=self.get_type_comment(node)):
self.traverse(node.body)
if node.orelse:
self.fill("else")
self.fill("else", allow_semi=False)
with self.block():
self.traverse(node.orelse)

def visit_If(self, node):
self.fill("if ")
self.fill("if ", allow_semi=False)
self.traverse(node.test)
with self.block():
self.traverse(node.body)
# collapse nested ifs into equivalent elifs.
while node.orelse and len(node.orelse) == 1 and isinstance(node.orelse[0], If):
node = node.orelse[0]
self.fill("elif ")
self.fill("elif ", allow_semi=False)
self.traverse(node.test)
with self.block():
self.traverse(node.body)
# final else
if node.orelse:
self.fill("else")
self.fill("else", allow_semi=False)
with self.block():
self.traverse(node.orelse)

def visit_While(self, node):
self.fill("while ")
self.fill("while ", allow_semi=False)
self.traverse(node.test)
with self.block():
self.traverse(node.body)
if node.orelse:
self.fill("else")
self.fill("else", allow_semi=False)
with self.block():
self.traverse(node.orelse)

def visit_With(self, node):
self.fill("with ")
self.fill("with ", allow_semi=False)
self.interleave(lambda: self.write(", "), self.traverse, node.items)
with self.block(extra=self.get_type_comment(node)):
self.traverse(node.body)

def visit_AsyncWith(self, node):
self.fill("async with ")
self.fill("async with ", allow_semi=False)
self.interleave(lambda: self.write(", "), self.traverse, node.items)
with self.block(extra=self.get_type_comment(node)):
self.traverse(node.body)
Expand Down Expand Up @@ -1264,7 +1279,7 @@ def visit_Name(self, node):
self.write(node.id)

def _write_docstring(self, node):
self.fill()
self.fill(allow_semi=False)
if node.kind == "u":
self.write("u")
self._write_str_avoiding_backslashes(node.value, quote_types=_MULTI_QUOTES)
Expand Down Expand Up @@ -1558,7 +1573,7 @@ def visit_Slice(self, node):
self.traverse(node.step)

def visit_Match(self, node):
self.fill("match ")
self.fill("match ", allow_semi=False)
self.traverse(node.subject)
with self.block():
for case in node.cases:
Expand Down Expand Up @@ -1652,7 +1667,7 @@ def visit_withitem(self, node):
self.traverse(node.optional_vars)

def visit_match_case(self, node):
self.fill("case ")
self.fill("case ", allow_semi=False)
self.traverse(node.pattern)
if node.guard:
self.write(" if ")
Expand Down
11 changes: 11 additions & 0 deletions Lib/test/test_ast/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,17 @@ def test_repr_large_input_crash(self):
r"Exceeds the limit \(\d+ digits\)"):
repr(ast.Constant(value=eval(source)))

def test_unparse_interactive(self):
# gh-129598: Fix of ast.unparse() when ast.Interactive contains multiple statements
source = "i = 1; 'expr'; raise Exception"
self.assertEqual(source, ast.unparse(ast.parse(source, mode='single')))
source = "if i:\n 'expr'\nelse:\n raise Exception"
unparsed = "if i:\n 'expr'\nelse:\n raise Exception"
self.assertEqual(unparsed, ast.unparse(ast.parse(source, mode='single')))
source = "@decorator\ndef func():\n 'docstring'\n i = 1; 'expr'; raise Exception"
unparsed = '''@decorator\ndef func():\n """docstring"""\n i = 1\n 'expr'\n raise Exception'''
self.assertEqual(unparsed, ast.unparse(ast.parse(source, mode='single')))


class CopyTests(unittest.TestCase):
"""Test copying and pickling AST nodes."""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix of :func:`ast.unparse` when :class:`ast.Interactive` contains multiple statements.
Loading