diff --git a/rewrite/rewrite/java/support_types.py b/rewrite/rewrite/java/support_types.py index 685aad3a..54e27fc2 100644 --- a/rewrite/rewrite/java/support_types.py +++ b/rewrite/rewrite/java/support_types.py @@ -106,8 +106,8 @@ def with_comments(self, comments: List[Comment]) -> Space: _whitespace: Optional[str] @property - def whitespace(self) -> Optional[str]: - return self._whitespace + def whitespace(self) -> str: + return self._whitespace if self._whitespace is not None else "" def with_whitespace(self, whitespace: Optional[str]) -> Space: return self if whitespace is self._whitespace else replace(self, _whitespace=whitespace) @@ -127,6 +127,36 @@ def format_first_prefix(cls, trees: List[J2], prefix: Space) -> List[J2]: return formatted_trees return trees + @property + def indent(self) -> str: + """ + The indentation after the last newline of either the last comment's suffix + or the global whitespace if no comments exist. + """ + return self._get_whitespace_indent(self.last_whitespace) + + @property + def last_whitespace(self) -> str: + """ + The raw suffix from the last comment if it exists, otherwise the global + whitespace (or empty string if whitespace is None). + """ + if self._comments: + return self._comments[-1].suffix + return self._whitespace if self._whitespace is not None else "" + + @staticmethod + def _get_whitespace_indent(whitespace: Optional[str]) -> str: + """ + A helper method that extracts everything after the last newline character + in `whitespace`. If no newline is present, returns `whitespace` as-is. + If the last newline is at the end, returns an empty string. + """ + if not whitespace: + return "" + last_newline = whitespace.rfind('\n') + return whitespace if last_newline == -1 else whitespace[last_newline + 1:] + EMPTY: ClassVar[Space] SINGLE_SPACE: ClassVar[Space] diff --git a/rewrite/rewrite/python/format/auto_format.py b/rewrite/rewrite/python/format/auto_format.py index 9c115958..54543512 100644 --- a/rewrite/rewrite/python/format/auto_format.py +++ b/rewrite/rewrite/python/format/auto_format.py @@ -6,7 +6,8 @@ from .normalize_tabs_or_spaces import NormalizeTabsOrSpacesVisitor from .remove_trailing_whitespace_visitor import RemoveTrailingWhitespaceVisitor from .spaces_visitor import SpacesVisitor -from .. import TabsAndIndentsStyle, GeneralFormatStyle +from .tabs_and_indents_visitor import TabsAndIndentsVisitor +from .. import TabsAndIndentsStyle, GeneralFormatStyle, WrappingAndBracesStyle from ..style import BlankLinesStyle, SpacesStyle, IntelliJ from ..visitor import PythonVisitor from ... import Recipe, Tree, Cursor @@ -28,13 +29,22 @@ def visit(self, tree: Optional[Tree], p: P, parent: Optional[Cursor] = None) -> cu = tree if isinstance(tree, JavaSourceFile) else self._cursor.first_enclosing_or_throw(JavaSourceFile) tree = NormalizeFormatVisitor(self._stop_after).visit(tree, p, self._cursor.fork()) + tree = BlankLinesVisitor(cu.get_style(BlankLinesStyle) or IntelliJ.blank_lines(), self._stop_after).visit(tree, p, self._cursor.fork()) + tree = SpacesVisitor(cu.get_style(SpacesStyle) or IntelliJ.spaces(), self._stop_after).visit(tree, p, self._cursor.fork()) + tree = NormalizeTabsOrSpacesVisitor( cu.get_style(TabsAndIndentsStyle) or IntelliJ.tabs_and_indents(), self._stop_after ).visit(tree, p, self._cursor.fork()) + + tree = TabsAndIndentsVisitor(cu.get_style(TabsAndIndentsStyle) or IntelliJ.tabs_and_indents(), + self._stop_after).visit(tree, p, self._cursor.fork()) + tree = NormalizeLineBreaksVisitor(cu.get_style(GeneralFormatStyle) or GeneralFormatStyle(False), self._stop_after).visit(tree, p, self._cursor.fork()) + tree = RemoveTrailingWhitespaceVisitor(self._stop_after).visit(tree, self._cursor.fork()) + return tree diff --git a/rewrite/rewrite/python/format/tabs_and_indents_visitor.py b/rewrite/rewrite/python/format/tabs_and_indents_visitor.py new file mode 100644 index 00000000..b486a109 --- /dev/null +++ b/rewrite/rewrite/python/format/tabs_and_indents_visitor.py @@ -0,0 +1,398 @@ +from __future__ import annotations + +import sys +from enum import Enum, auto +from typing import TypeVar, Optional, Union, cast, List + +from rewrite import Tree, Cursor, list_map +from rewrite.java import J, Space, JRightPadded, JLeftPadded, JContainer, JavaSourceFile, Case, WhileLoop, \ + Block, If, Label, ArrayDimension, ClassDeclaration, Empty, \ + Binary, MethodInvocation, FieldAccess, Identifier, Lambda, TextComment, Comment, TrailingComma +from rewrite.python import PythonVisitor, TabsAndIndentsStyle, PySpace, PyContainer, PyRightPadded, DictLiteral, \ + CollectionLiteral, ForLoop +from rewrite.visitor import P, T + +J2 = TypeVar('J2', bound=J) + + +class TabsAndIndentsVisitor(PythonVisitor[P]): + + def __init__(self, style: TabsAndIndentsStyle, stop_after: Optional[Tree] = None): + self._stop_after = stop_after + self._style = style + self._stop = False + + def visit(self, tree: Optional[Tree], p: P, parent: Optional[Cursor] = None) -> Optional[J]: + if parent is not None: + self._cursor = parent + if tree is None: + return cast(Optional[J], self.default_value(None, p)) + + for c in parent.get_path_as_cursors(): + v = c.value + space = None + if isinstance(v, J): + space = v.prefix + elif isinstance(v, JRightPadded): + space = v.after + elif isinstance(v, JLeftPadded): + space = v.before + elif isinstance(v, JContainer): + space = v.before + + if space is not None and '\n' in space.last_whitespace: + indent = self.find_indent(space) + if indent != 0: + c.put_message("last_indent", indent) + + for next_parent in parent.get_path(): + if isinstance(next_parent, J): + self.pre_visit(next_parent, p) + break + + return super().visit(tree, p) + + def pre_visit(self, tree: T, p: P) -> Optional[T]: + if isinstance(tree, (JavaSourceFile, Label, ArrayDimension, ClassDeclaration)): + self.cursor.put_message("indent_type", self.IndentType.ALIGN) + elif isinstance(tree, + (Block, If, If.Else, ForLoop, WhileLoop, Case, DictLiteral, CollectionLiteral)): + # NOTE: Added CollectionLiteral, DictLiteral here + self.cursor.put_message("indent_type", self.IndentType.INDENT) + else: + self.cursor.put_message("indent_type", self.IndentType.CONTINUATION_INDENT) + + return tree + + def post_visit(self, tree: T, p: P) -> Optional[T]: + if self._stop_after and tree == self._stop_after: + self._stop = True + return tree + + def visit_space(self, space: Optional[Space], loc: Optional[Union[PySpace.Location, Space.Location]], + p: P) -> Space: + if space is None: + return space # type: ignore + + self._cursor.put_message("last_location", loc) + parent = self._cursor.parent + + if loc == Space.Location.METHOD_SELECT_SUFFIX: + chained_indent = cast(int, self.cursor.parent_tree_cursor().get_message("chained_indent", None)) + if chained_indent is not None: + self.cursor.parent_tree_cursor().put_message("last_indent", chained_indent) + return self._indent_to(space, chained_indent, loc) + + indent = cast(int, self.cursor.get_nearest_message("last_indent")) or 0 + indent_type = self.cursor.parent_or_throw.get_nearest_message("indent_type") or self.IndentType.ALIGN + + if not space.comments and '\n' not in space.last_whitespace or parent is None: + return space + + cursor_value = self._cursor.value + + # Block spaces are always aligned to their parent + # The second condition ensure init blocks are ignored. + # TODO: Second condition might be removed since it's not relevant for Python + align_block_prefix_to_parent = loc is Space.Location.BLOCK_PREFIX and '\n' in space.whitespace and \ + (isinstance(cursor_value, Block) and not isinstance( + self.cursor.parent_tree_cursor().value, Block)) + + align_block_to_parent = loc in ( + Space.Location.NEW_ARRAY_INITIALIZER_SUFFIX, + Space.Location.CATCH_PREFIX, + Space.Location.TRY_FINALLY, + Space.Location.ELSE_PREFIX, + ) + + if (loc == Space.Location.EXTENDS and "\n" in space.whitespace) or \ + Space.Location.EXTENDS == self.cursor.parent_or_throw.get_message("last_location", None): + indent_type = self.IndentType.CONTINUATION_INDENT + + if align_block_prefix_to_parent or align_block_to_parent: + indent_type = self.IndentType.ALIGN + + if indent_type == self.IndentType.INDENT: + indent += self._style.indent_size + elif indent_type == self.IndentType.CONTINUATION_INDENT: + indent += self._style.continuation_indent + + s: Space = self._indent_to(space, indent, loc) + if isinstance(cursor_value, J): + self.cursor.put_message("last_indent", indent) + elif loc == Space.Location.METHOD_SELECT_SUFFIX: + self.cursor.parent_tree_cursor().put_message("last_indent", indent) + + return s + + + def visit_right_padded(self, right: Optional[JRightPadded[T]], + loc: Union[PyRightPadded.Location, JRightPadded.Location], p: P) -> Optional[ + JRightPadded[T]]: + + if right is None: + return None + + self.cursor = Cursor(self._cursor, right) + + indent: int = cast(int, self.cursor.get_nearest_message("last_indent")) or 0 + + t: T = right.element + after = right.after + # TODO: Check if the visit_and_cast is really required here + + if isinstance(t, J): + elem = t + trailing_comma = right.markers.find_first(TrailingComma) + if '\n' in right.after.last_whitespace or '\n' in elem.prefix.last_whitespace: + if loc in (JRightPadded.Location.FOR_CONDITION, + JRightPadded.Location.FOR_UPDATE): + raise ValueError("This case should not be possible, should be safe for removal...") + elif loc in (JRightPadded.Location.METHOD_DECLARATION_PARAMETER, + JRightPadded.Location.RECORD_STATE_VECTOR): + if isinstance(elem, Empty): + elem = elem.with_prefix(self._indent_to(elem.prefix, indent, loc.after_location)) + after = right.after + else: + container: JContainer[J] = cast(JContainer[J], self.cursor.parent_or_throw.value) + elements: List[J] = container.elements + last_arg: J = elements[-1] + + # TODO: style.MethodDeclarationParameters doesn't exist for Python + # but should be self._style.method_declaration_parameters.align_when_multiple + elem = self.visit_and_cast(elem, J, p) + after = self._indent_to(right.after, + indent if t is last_arg else self._style.continuation_indent, + loc.after_location) + + elif loc == JRightPadded.Location.METHOD_INVOCATION_ARGUMENT: + elem, after = self._visit_method_invocation_argument_j_type(elem, right, indent, loc, p) + elif loc in (JRightPadded.Location.NEW_CLASS_ARGUMENTS, + JRightPadded.Location.ARRAY_INDEX, + JRightPadded.Location.PARENTHESES, + JRightPadded.Location.TYPE_PARAMETER): + elem = self.visit_and_cast(elem, J, p) + after = self._indent_to(right.after, indent, loc.after_location) + elif loc in (PyRightPadded.Location.COLLECTION_LITERAL_ELEMENT, PyRightPadded.Location.DICT_LITERAL_ELEMENT): + elem = self.visit_and_cast(elem, J, p) + args = cast(JContainer[J], self.cursor.parent_or_throw.value) + if not trailing_comma and args.padding.elements[-1] is right: + self.cursor.parent_or_throw.put_message("indent_type", self.IndentType.ALIGN) + after = self.visit_space(right.after, loc.after_location, p) + if trailing_comma: + self.cursor.parent_or_throw.put_message("indent_type", self.IndentType.ALIGN) + trailing_comma = trailing_comma.with_suffix(self.visit_space(trailing_comma.suffix, loc.after_location, p)) + right = right.with_markers(right.markers.compute_by_type(TrailingComma, lambda t: trailing_comma)) + elif loc == JRightPadded.Location.ANNOTATION_ARGUMENT: + raise NotImplementedError("Annotation argument not implemented") + else: + elem = self.visit_and_cast(elem, J, p) + after = self.visit_space(right.after, loc.after_location, p) + else: + if loc in (JRightPadded.Location.NEW_CLASS_ARGUMENTS, JRightPadded.Location.METHOD_INVOCATION_ARGUMENT): + any_other_arg_on_own_line = False + if "\n" not in elem.prefix.last_whitespace: + args = cast(JContainer[J], self.cursor.parent_or_throw.value) + for arg in args.padding.elements: + if arg == self.cursor.value: + continue + if "\n" in arg.element.prefix.last_whitespace: + any_other_arg_on_own_line = True + break + if not any_other_arg_on_own_line: + elem = self.visit_and_cast(elem, J, p) + after = self._indent_to(right.after, indent, loc.after_location) + + if not any_other_arg_on_own_line: + if not isinstance(elem, Binary): + if not isinstance(elem, MethodInvocation) or "\n" in elem.prefix.last_whitespace: + self.cursor.put_message("last_indent", indent + self._style.continuation_indent) + else: + method_invocation = elem + select = method_invocation.select + if isinstance(select, (FieldAccess, Identifier, MethodInvocation)): + self.cursor.put_message("last_indent", indent + self._style.continuation_indent) + + elem = self.visit_and_cast(elem, J, p) + after = self.visit_space(right.after, loc.after_location, p) + else: + elem = self.visit_and_cast(elem, J, p) + after = self.visit_space(right.after, loc.after_location, p) + + t = cast(T, elem) + else: + after = self.visit_space(right.after, loc.after_location, p) + + self.cursor = self.cursor.parent # type: ignore + return right.with_after(after).with_element(t) + + def visit_container(self, container: Optional[JContainer[J2]], + loc: Union[PyContainer.Location, JContainer.Location], p: P) -> JContainer[J2]: + if container is None: + return container # type: ignore + + self._cursor = Cursor(self._cursor, container) + + indent = cast(int, self.cursor.get_nearest_message("last_indent")) or 0 + if '\n' in container.before.last_whitespace: + if loc in (JContainer.Location.TYPE_PARAMETERS, + JContainer.Location.IMPLEMENTS, + JContainer.Location.THROWS, + JContainer.Location.NEW_CLASS_ARGUMENTS): + before = self._indent_to(container.before, indent + self._style.continuation_indent, + loc.before_location) + self.cursor.put_message("indent_type", self.IndentType.ALIGN) + self.cursor.put_message("last_indent", indent + self._style.continuation_indent) + else: + before = self.visit_space(container.before, loc.before_location, p) + js = list_map(lambda t: self.visit_right_padded(t, loc.element_location, p), container.padding.elements) + else: + if loc in (JContainer.Location.TYPE_PARAMETERS, + JContainer.Location.IMPLEMENTS, + JContainer.Location.THROWS, + JContainer.Location.NEW_CLASS_ARGUMENTS, + JContainer.Location.METHOD_INVOCATION_ARGUMENTS): + self.cursor.put_message("indent_type", self.IndentType.CONTINUATION_INDENT) + before = self.visit_space(container.before, loc.before_location, p) + else: + before = self.visit_space(container.before, loc.before_location, p) + js = list_map(lambda t: self.visit_right_padded(t, loc.element_location, p), container.padding.elements) + + self._cursor = self._cursor.parent # type: ignore + + if container.padding.elements is js and container.before is before: + return container + return JContainer(before, js, container.markers) + + def _indent_to(self, space: Space, column: int, space_location: Optional[Union[PySpace.Location, Space.Location]]) -> Space: + s = space + whitespace = s.whitespace + + if space_location == Space.Location.COMPILATION_UNIT_PREFIX and whitespace: + s = s.with_whitespace("") + elif not s.comments and "\n" not in s.last_whitespace: + return s + + if not s.comments: + indent = self.find_indent(s) + if indent != column: + shift = column - indent + s = s.with_whitespace(self._indent(whitespace, shift)) + else: + def whitespace_indent(text: str) -> str: + # TODO: Placeholder function, taken from java openrewrite.StringUtils + indent: List[str] = [] + for c in text: + if c == '\n' or c == '\r': + return ''.join(indent) + elif c.isspace(): + indent.append(c) + else: + return ''.join(indent) + return ''.join(indent) + + # TODO: This is the java version, however the python version is probably different + has_file_leading_comment = space.comments and ( + (space_location == Space.Location.COMPILATION_UNIT_PREFIX) or ( + space_location == Space.Location.BLOCK_END) or + (space_location == Space.Location.CLASS_DECLARATION_PREFIX and space.comments[0].multiline) + ) + + final_column = column + self._style.indent_size if space_location == Space.Location.BLOCK_END else column + last_indent: str = space.whitespace[space.whitespace.rfind('\n') + 1:] + indent = self._get_length_of_whitespace(whitespace_indent(last_indent)) + + if indent != final_column: + if (has_file_leading_comment or ("\n" in whitespace)) and ( + # Do not shift single-line comments at column 0. + not (s.comments and isinstance(s.comments[0], TextComment) and + not s.comments[0].multiline and self._get_length_of_whitespace(s.whitespace) == 0)): + shift = final_column - indent + s = s.with_whitespace(whitespace[:whitespace.rfind('\n') + 1] + self._indent(last_indent, shift)) + + final_space = s + last_comment_pos = len(s.comments) - 1 + + def _process_comment(i: int, c: Comment) -> Comment: + if isinstance(c, TextComment) and not c.multiline: + # Do not shift single line comments at col 0. + if i != last_comment_pos and self._get_length_of_whitespace(c.suffix) == 0: + return c + + prior_suffix = space.whitespace if i == 0 else final_space.comments[i - 1].suffix + + if space_location == Space.Location.BLOCK_END and i != len(final_space.comments) - 1: + to_column = column + self._style.indent_size + else: + to_column = column + + new_c = c + if "\n" in prior_suffix or has_file_leading_comment: + new_c = c + + if '\n' in new_c.suffix: + suffix_indent = self._get_length_of_whitespace(new_c.suffix) + shift = to_column - suffix_indent + new_c = new_c.with_suffix(self._indent(new_c.suffix, shift)) + + return new_c + + s = s.with_comments(list_map(lambda i, c: _process_comment(c, i), s.comments)) + return s + + def _indent(self, whitespace: str, shift: int) -> str: + return self._shift(whitespace, shift) + + def _shift(self, text: str, shift: int) -> str: + tab_indent = self._style.tab_size + if not self._style.use_tab_character: + tab_indent = sys.maxsize + + if shift > 0: + text += '\t' * (shift // tab_indent) + text += ' ' * (shift % tab_indent) + else: + if self._style.use_tab_character: + len_text = len(text) + (shift // tab_indent) + else: + len_text = len(text) + shift + if len_text >= 0: + text = text[:len_text] + + return text + + def find_indent(self, space: Space) -> int: + return self._get_length_of_whitespace(space.indent) + + def _get_length_of_whitespace(self, whitespace: Optional[str]) -> int: + if whitespace is None: + return 0 + length = 0 + for c in whitespace: + length += self._style.tab_size if c == '\t' else 1 + if c in ('\n', '\r'): + length = 0 + return length + + def _visit_method_invocation_argument_j_type(self, elem: J, right: JRightPadded[T], indent: int, loc: Union[PyRightPadded.Location, JRightPadded.Location], p: P) -> tuple[J, Space]: + if "\n" not in elem.prefix.last_whitespace and isinstance(elem, Lambda): + body = elem.body + if not isinstance(body, Binary): + if "\n" not in body.prefix.last_whitespace: + self.cursor.parent_or_throw.put_message("last_indent", indent + self._style.continuation_indent) + + elem = self.visit_and_cast(elem, J, p) + after = self._indent_to(right.after, indent, loc.after_location) + if after.comments or "\n" in after.last_whitespace: + parent = self.cursor.parent_tree_cursor() + grandparent = parent.parent_tree_cursor() + # propagate indentation up in the method chain hierarchy + if isinstance(grandparent.value, MethodInvocation) and grandparent.value.select == parent.value: + grandparent.put_message("last_indent", indent) + grandparent.put_message("chained_indent", indent) + return elem, after + + class IndentType(Enum): + ALIGN = auto() + INDENT = auto() + CONTINUATION_INDENT = auto() diff --git a/rewrite/rewrite/python/utils/__init__.py b/rewrite/rewrite/python/utils/__init__.py new file mode 100644 index 00000000..ba1a43ca --- /dev/null +++ b/rewrite/rewrite/python/utils/__init__.py @@ -0,0 +1,5 @@ +from .tree_visiting_printer import * + +__all__ = [ + "TreeVisitingPrinter" +] diff --git a/rewrite/rewrite/python/utils/tree_visiting_printer.py b/rewrite/rewrite/python/utils/tree_visiting_printer.py new file mode 100644 index 00000000..4b6a2d6f --- /dev/null +++ b/rewrite/rewrite/python/utils/tree_visiting_printer.py @@ -0,0 +1,124 @@ +from typing import Optional, Union + +from rewrite import Cursor +from rewrite import Tree +from rewrite.java import Space, Literal, Identifier, JRightPadded, JLeftPadded, Modifier +from rewrite.python import PythonVisitor, PySpace +from rewrite.visitor import T, P + + +class TreeVisitingPrinter(PythonVisitor): + INDENT = " " + ELEMENT_PREFIX = "\\---" + CONTINUE_PREFIX = "|---" + UNVISITED_PREFIX = "#" + BRANCH_CONTINUE_CHAR = '|' + BRANCH_END_CHAR = '\\' + CONTENT_MAX_LENGTH = 120 + + _last_cursor_stack = [] + _lines = [] + + def __init__(self, indent: str = " "): + super().__init__() + self.INDENT = indent + + def visit(self, tree: Optional[Tree], p: P, parent: Optional[Cursor] = None) -> Optional[T]: + if tree is None: + return super().visit(None, p, parent) # pyright: ignore [reportReturnType] + + _current_stack = list(self._cursor.get_path()) + _current_stack.reverse() + depth = len(_current_stack) + if not self._last_cursor_stack: + self._last_cursor_stack = _current_stack + [tree] + else: + diff_position = self.find_diff_pos(_current_stack, self._last_cursor_stack) + if diff_position >= 0: + for i in _current_stack[diff_position:]: + self._lines += [[depth, i]] + self._last_cursor_stack = self._last_cursor_stack[:diff_position] + + self._lines += [[depth, tree]] + self._last_cursor_stack = _current_stack + [tree] + return super().visit(tree, p, parent) # pyright: ignore [reportReturnType] + + def visit_space(self, space: Optional[Space], loc: Optional[Union[PySpace.Location, Space.Location]], + p: P) -> Space: + print("Loc", loc, "el,", self._print_element(self.cursor.parent.value)) + return super().visit_space(space, loc, p) + + def _print_tree(self) -> str: + output = "" + offset = 0 + for idx, (depth, element) in enumerate(self._lines): + offset = depth if idx == 0 else offset + padding = self.INDENT * (depth - offset) + if idx + 1 < len(self._lines) and self._lines[idx + 1][0] <= depth or idx + 1 == len(self._lines): + output += padding + self.CONTINUE_PREFIX + self._print_element(element) + "\n" + else: + output += padding + self.ELEMENT_PREFIX + self._print_element(element) + "\n" + return output + + def _print_element(self, element) -> str: + type_name = type(element).__name__ + line = [] + + if hasattr(element, "before"): + line.append(f"before= {self._print_space(element.before)}") + + if hasattr(element, "after"): + line.append(f"after= {self._print_space(element.after)}") + + if hasattr(element, "suffix"): + line.append(f"suffix= {self._print_space(element.suffix)}") + + if hasattr(element, "prefix"): + line.append(f"prefix= {self._print_space(element.prefix)}") + + if isinstance(element, Identifier): + type_name = f'{type_name} | "{element.simple_name}"' + + if isinstance(element, Literal): + type_name = f'{type_name} | {element.value_source}' + + if isinstance(element, JRightPadded): + return f'{type_name} | after= {self._print_space(element.after)}' + + if isinstance(element, JLeftPadded): + return f'{type_name} | before= {self._print_space(element.before)}' + + if isinstance(element, Modifier): + return type_name + ( + (" | " + element.type.name) if hasattr(element, "type") else "") + + if line: + return type_name + " | " + " | ".join(line) + return type_name + + @staticmethod + def _print_space(space: Space) -> str: + parts = [] + if space.whitespace: + parts.append(f'whitespace="{repr(space.whitespace)}"') + if space.comments: + parts.append(f'comments="{space.comments}"') + return " ".join(parts).replace("\n", "\\s\n") + + @staticmethod + def print_tree_all(tree: "Tree") -> str: + visitor = TreeVisitingPrinter() + visitor.visit(tree, None, None) + print(visitor._print_tree()) + return "" + + def find_diff_pos(self, cursor_stack, last_cursor_stack): + diff_pos = -1 + for i in range(len(cursor_stack)): + if i >= len(last_cursor_stack): + diff_pos = i + break + if cursor_stack[i] != last_cursor_stack[i]: + diff_pos = i + break + return diff_pos diff --git a/rewrite/rewrite/visitor.py b/rewrite/rewrite/visitor.py index 9d69274e..85b6dcc3 100644 --- a/rewrite/rewrite/visitor.py +++ b/rewrite/rewrite/visitor.py @@ -2,7 +2,7 @@ from abc import ABC from dataclasses import dataclass -from typing import TypeVar, Optional, Dict, List, Any, cast, Type, ClassVar, Generic +from typing import TypeVar, Optional, Dict, List, Any, cast, Type, ClassVar, Generic, Generator from .execution import RecipeRunException from .markers import Marker, Markers @@ -32,13 +32,13 @@ def put_message(self, key: str, value: object) -> None: object.__setattr__(self, 'messages', {}) self.messages[key] = value # type: ignore - def parent_tree_cursor(self) -> Optional[Cursor]: + def parent_tree_cursor(self) -> Cursor: c = self.parent while c is not None: - if isinstance(c.value, Tree): + if isinstance(c.value, Tree) or c.value == Cursor.ROOT_VALUE: return c c = c.parent - return None + raise ValueError("Expected to find parent tree cursor for " + str(self)) def first_enclosing_or_throw(self, type: Type[P]) -> P: result = self.first_enclosing(type) @@ -57,6 +57,30 @@ def first_enclosing(self, type_: Type[P]) -> Optional[P]: def fork(self) -> Cursor: return Cursor(self.parent.fork(), self.value) if self.parent else self + def get_path(self) -> Generator[Any]: + c = self + while c is not None: + yield c.value + c = c.parent + + def get_path_as_cursors(self) -> Generator[Cursor]: + c = self + while c is not None: + yield c + c = c.parent + + def get_nearest_message(self, key: str) -> Optional[object]: + for c in self.get_path_as_cursors(): + if c.messages is not None and key in c.messages: + return c.messages[key] + return None + + @property + def parent_or_throw(self) -> Cursor: + if self.parent is None: + raise ValueError("Cursor is expected to have a parent:", self) + return self.parent + class TreeVisitor(ABC, Generic[T, P]): _visit_count: int = 0 diff --git a/rewrite/tests/python/all/format/tabs_and_indents_visitor_test.py b/rewrite/tests/python/all/format/tabs_and_indents_visitor_test.py new file mode 100644 index 00000000..6b8e78b3 --- /dev/null +++ b/rewrite/tests/python/all/format/tabs_and_indents_visitor_test.py @@ -0,0 +1,751 @@ +import pytest + +from rewrite.python import IntelliJ +from rewrite.python.format import TabsAndIndentsVisitor +from rewrite.test import rewrite_run, python, RecipeSpec, from_visitor + + +def test_multi_assignment(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4).with_indent_size(4) + rewrite_run( + # language=python + python( + """ + def assign_values(): + a, b = 1, 2 + x, y, z = 3, 4, 5 + return a, b, x, y, z + """, + """ + def assign_values(): + a, b = 1, 2 + x, y, z = 3, 4, 5 + return a, b, x, y, z + """ + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + # spec = RecipeSpec().with_recipes(from_visitor(AutoFormatVisitor())) + ) + + +def test_if_else_statement(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4).with_indent_size(4) + rewrite_run( + # language=python + python( + """ + def check_value(x): + if x > 0: + return "Positive" + else: + return "Non-positive" + """, + """ + def check_value(x): + if x > 0: + return "Positive" + else: + return "Non-positive" + """ + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_if_else_statement_no_else_with_extra_statements(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4).with_indent_size(4) + rewrite_run( + # language=python + python( + """ + def check_value(x): + if x > 0: + a = 1 + x + return "Positive" + a = -1 + x + return "Non-positive" + """, + """ + def check_value(x): + if x > 0: + a = 1 + x + return "Positive" + a = -1 + x + return "Non-positive" + """ + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_if_else_statement_no_else_multi_return_values(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4).with_indent_size(4) + rewrite_run( + # language=python + python( + """\ + def check_value(x): + if x > 0: + a = 1 + x + return a, "Positive" + return a + """, + """\ + def check_value(x): + if x > 0: + a = 1 + x + return a, "Positive" + return a + """ + ), + # spec=RecipeSpec().with_recipes(from_visitor(AutoFormatVisitor())) + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_if_else_statement_no_else_multi_return_values_as_tuple(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4).with_indent_size(4) + rewrite_run( + # language=python + python( + """\ + def check_value(x): + if x > 0: + a = 1 + x + return (a, "Positive") + return a + """, + """\ + def check_value(x): + if x > 0: + a = 1 + x + return (a, "Positive") + return a + """ + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_if_elif_else_statement(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4).with_indent_size(4) + rewrite_run( + # language=python + python( + """ + def check_value(x): + if x > 0: + return "Positive" + elif x < 0: + return "Negative" + else: + return "Null" + """, + """ + def check_value(x): + if x > 0: + return "Positive" + elif x < 0: + return "Negative" + else: + return "Null" + """ + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_for_statement(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4).with_indent_size(4) + rewrite_run( + # language=python + python( + """ + def sum_list(lst): + total = 0 + for num in lst: + total += num + return total + """, + """ + def sum_list(lst): + total = 0 + for num in lst: + total += num + return total + """ + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_for_statement_with_list_comprehension(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4).with_indent_size(4) + rewrite_run( + # language=python + python( + """ + def even_numbers(lst): + return [x for x in lst if x % 2 == 0] + """, + """ + def even_numbers(lst): + return [x for x in lst if x % 2 == 0] + """ + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_for_statement_with_list_comprehension_multiline(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4).with_indent_size(4) + rewrite_run( + # language=python + python( + """ + def even_numbers(lst): + return [x for x + in lst if x % 2 == 0] + """, + """ + def even_numbers(lst): + return [x for x + in lst if x % 2 == 0] + """ + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_while_statement(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4).with_indent_size(4) + rewrite_run( + # language=python + python( + """ + def countdown(n): + while n > 0: + print(n) + n -= 1 + """, + """ + def countdown(n): + while n > 0: + print(n) + n -= 1 + """ + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_class_statement(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4).with_indent_size(4) + rewrite_run( + # language=python + python( + """ + class MyClass: + def __init__(self, value): + self.value = value + def get_value(self): + return self.value + """, + """ + class MyClass: + def __init__(self, value): + self.value = value + def get_value(self): + return self.value + """ + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_with_statement(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4).with_indent_size(4) + rewrite_run( + # language=python + python( + """ + def read_file(file_path): + with open(file_path, 'r') as file: + content = file.read() + return content + """, + """ + def read_file(file_path): + with open(file_path, 'r') as file: + content = file.read() + return content + """ + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_try_statement_basic(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4).with_indent_size(4) + rewrite_run( + # language=python + python( + """ + def divide(a, b): + try: + c = a / b + except ZeroDivisionError: + return None + return c + """, + """ + def divide(a, b): + try: + c = a / b + except ZeroDivisionError: + return None + return c + """ + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_try_statement_with_multi_return(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4).with_indent_size(4) + rewrite_run( + # language=python + python( + """ + def divide(a, b): + try: + c = a / b + if c > 42: return a, c + except ZeroDivisionError: + return None + return a, b + """, + """ + def divide(a, b): + try: + c = a / b + if c > 42: return a, c + except ZeroDivisionError: + return None + return a, b + """ + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_basic_indent_modification(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(2).with_indent_size(4) + rewrite_run( + # language=python + python( + ''' + def my_function(): + return None + ''', + ''' + def my_function(): + return None + ''' + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_multiline_list(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4).with_indent_size( + 4).with_continuation_indent(8) + rewrite_run( + # language=python + python( + """\ + my_list = [ #cool + 1, + 2, + 3, + 4 + ] + """, + """\ + my_list = [ #cool + 1, + 2, + 3, + 4 + ] + """ + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_multiline_call_with_positional_args_no_align_multiline(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4) + # noinspection PyInconsistentIndentation + rewrite_run( + # language=python + python( + """ + def long_function_name(var_one, var_two, + var_three, + var_four): + print(var_one) + """, + """ + def long_function_name(var_one, var_two, + var_three, + var_four): + print(var_one) + """ + ), + spec=RecipeSpec() + .with_recipes( + from_visitor(TabsAndIndentsVisitor(style)) + ) + ) + + +def test_multiline_call_with_positional_args_and_no_arg_first_line(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4) + # noinspection PyInconsistentIndentation + rewrite_run( + # language=python + python( + """ + def long_function_name( + var_one, + var_two, var_three, + var_four): + print(var_one) + """, + """ + def long_function_name( + var_one, + var_two, var_three, + var_four): + print(var_one) + """ + ), + spec=RecipeSpec() + .with_recipes( + from_visitor(TabsAndIndentsVisitor(style)) + ) + ) + + +def test_multiline_call_with_args_without_multiline_align(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4) + rewrite_run( + # language=python + python( + """ + result = long_function_name(10, 'foo', + another_arg=42, + final_arg="bar") + """, + """ + result = long_function_name(10, 'foo', + another_arg=42, + final_arg="bar") + """ + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_multiline_list_inside_function(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4).with_indent_size(4) + rewrite_run( + # language=python + python( + """ + def create_list(): + my_list = [ + 1, + 2, + 3, + 4 + ] + return my_list + """, + """ + def create_list(): + my_list = [ + 1, + 2, + 3, + 4 + ] + return my_list + """ + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + +def test_multiline_list_inside_function_with_trailing_comma(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4).with_indent_size(4) + rewrite_run( + # language=python + python( + """ + def create_list(): + my_list = [ + 1, + 2, + 3, + 4, + ] + return my_list + """, + """ + def create_list(): + my_list = [ + 1, + 2, + 3, + 4, + ] + return my_list + """ + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_basic_dictionary(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4).with_indent_size(4) + rewrite_run( + # language=python + python( + """ + config = { + "key1": "value1", + "key2": "value2" + } + """, + """ + config = { + "key1": "value1", + "key2": "value2" + } + """ + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_nested_dictionary(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4) + rewrite_run( + # language=python + python( + """ + config = { + "section": { + "key1": "value1", + "key2": [10, 20, + 30] + }, + "another_section": {"nested_key": "val"} + } + """, + """ + config = { + "section": { + "key1": "value1", + "key2": [10, 20, + 30] + }, + "another_section": {"nested_key": "val"} + } + """ + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_nested_dictionary_with_trailing_commas(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4) + rewrite_run( + # language=python + python( + """ + config = { + "section": { + "key1": "value1", + "key2": [10, 20, + 30], + }, + "another_section": {"nested_key": "val"} + } + """, + """ + config = { + "section": { + "key1": "value1", + "key2": [10, 20, + 30], + }, + "another_section": {"nested_key": "val"} + } + """ + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_list_comprehension(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4) + rewrite_run( + # language=python + python( + """ + def even_numbers(n): + return [x for x in range(n) + if x % 2 == 0] + """, + """ + def even_numbers(n): + return [x for x in range(n) + if x % 2 == 0] + """ + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_comment_alignment(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4) + rewrite_run( + # language=python + python( + ''' + # Informative comment 1 + def my_function(a, b): + # Informative comment 2 + + # Informative comment 3 + if a > b: + # cool + a = b + 1 + return None # Informative comment 4 + # Informative comment 5 + ''', + ''' + # Informative comment 1 + def my_function(a, b): + # Informative comment 2 + + # Informative comment 3 + if a > b: + # cool + a = b + 1 + return None # Informative comment 4 + # Informative comment 5 + ''' + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + return None + + +def test_comment_alignment_if_and_return(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4) + rewrite_run( + # language=python + python( + ''' + def my_function(a, b): + if a > b: + # cool + a = b + 1 + # cool + return None + ''', + ''' + def my_function(a, b): + if a > b: + # cool + a = b + 1 + # cool + return None + ''' + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + return None + + +def test_docstring_alignment(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4) + rewrite_run( + # language=python + python( + ''' + def my_function(): + """ + This is a docstring that + should align with the function body. + """ + return None + ''', + ''' + def my_function(): + """ + This is a docstring that + should align with the function body. + """ + return None + ''' + ), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +def test_method_select_suffix_already_correct(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4) + rewrite_run( + # language=python + python( + """ + x = ("foo".startswith("f")) + """), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +@pytest.mark.xfail +def test_method_select_suffix_new_line_already_correct(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4) + rewrite_run( + # language=python + python( + """ + x = ("foo" + .startswith("f")) + """), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + ) + + +@pytest.mark.xfail +def test_method_select_suffix(): + style = IntelliJ.tabs_and_indents().with_use_tab_character(False).with_tab_size(4) + rewrite_run( + # language=python + python( + """ + x = ("foo" + .startswith("f")) + """, + """ + x = ("foo" + .startswith("f")) + """), + spec=RecipeSpec().with_recipes(from_visitor(TabsAndIndentsVisitor(style))) + )