diff --git a/rewrite/rewrite/python/format/auto_format.py b/rewrite/rewrite/python/format/auto_format.py index c4f68aed..9c115958 100644 --- a/rewrite/rewrite/python/format/auto_format.py +++ b/rewrite/rewrite/python/format/auto_format.py @@ -2,10 +2,11 @@ from .blank_lines import BlankLinesVisitor from .normalize_format import NormalizeFormatVisitor +from .normalize_line_breaks_visitor import NormalizeLineBreaksVisitor +from .normalize_tabs_or_spaces import NormalizeTabsOrSpacesVisitor from .remove_trailing_whitespace_visitor import RemoveTrailingWhitespaceVisitor from .spaces_visitor import SpacesVisitor -from .normalize_tabs_or_spaces import NormalizeTabsOrSpacesVisitor -from .. import TabsAndIndentsStyle +from .. import TabsAndIndentsStyle, GeneralFormatStyle from ..style import BlankLinesStyle, SpacesStyle, IntelliJ from ..visitor import PythonVisitor from ... import Recipe, Tree, Cursor @@ -33,5 +34,7 @@ def visit(self, tree: Optional[Tree], p: P, parent: Optional[Cursor] = None) -> cu.get_style(TabsAndIndentsStyle) or IntelliJ.tabs_and_indents(), self._stop_after ).visit(tree, p, self._cursor.fork()) + tree = NormalizeLineBreaksVisitor(cu.get_style(GeneralFormatStyle) or GeneralFormatStyle(False), + self._stop_after).visit(tree, p, self._cursor.fork()) tree = RemoveTrailingWhitespaceVisitor(self._stop_after).visit(tree, self._cursor.fork()) return tree diff --git a/rewrite/rewrite/python/format/normalize_line_breaks_visitor.py b/rewrite/rewrite/python/format/normalize_line_breaks_visitor.py new file mode 100644 index 00000000..525dd9e9 --- /dev/null +++ b/rewrite/rewrite/python/format/normalize_line_breaks_visitor.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from typing import Optional, TypeVar, Union + +from rewrite import Tree, P, Cursor, list_map +from rewrite.java import J, Space, Comment, TextComment +from rewrite.python import PythonVisitor, PySpace, GeneralFormatStyle, PyComment +from rewrite.visitor import T + +J2 = TypeVar('J2', bound=J) + + +class NormalizeLineBreaksVisitor(PythonVisitor): + def __init__(self, style: GeneralFormatStyle, stop_after: Tree = None): + self._stop_after = stop_after + self._stop = False + self._style = style + + def visit_space(self, space: Optional[Space], loc: Optional[Union[PySpace.Location, Space.Location]], + p: P) -> Space: + if not space or space is Space.EMPTY or not space.whitespace: + return space + s = space.with_whitespace(_normalize_new_lines(space.whitespace, self._style.use_crlf_new_lines)) + + def process_comment(comment: Comment) -> Comment: + if comment.multiline: + if isinstance(comment, PyComment): + comment = comment.with_suffix(_normalize_new_lines(comment.suffix, self._style.use_crlf_new_lines)) + # TODO: Call PyComment Visitor, but this is not implemented yet.... + return comment + elif isinstance(comment, TextComment): + comment = comment.with_text(_normalize_new_lines(comment.text, self._style.use_crlf_new_lines)) + + return comment.with_suffix(_normalize_new_lines(comment.suffix, self._style.use_crlf_new_lines)) + + return s.with_comments(list_map(process_comment, s.comments)) + + def post_visit(self, tree: T, _: object) -> Optional[T]: + if self._stop_after and tree == self._stop_after: + self._stop = True + return tree + + def visit(self, tree: Optional[Tree], p: P, parent: Optional[Cursor] = None) -> Optional[T]: + return tree if self._stop else super().visit(tree, p, parent) + + +STR = TypeVar('STR', bound=Optional[str]) + + +def _normalize_new_lines(text: STR, use_crlf: bool) -> STR: + """ + Normalize the line breaks in the given text to either use of CRLF or LF. + + :param text: The text to normalize. + :param use_crlf: Whether to use CRLF line breaks. + :return: The text with normalized line breaks. + """ + if text is None or '\n' not in text: + return text + + normalized = [] + for i, c in enumerate(text): + if use_crlf and c == '\n' and (i == 0 or text[i - 1] != '\r'): + normalized.append('\r\n') + elif use_crlf or c != '\r': + normalized.append(c) + return ''.join(normalized) diff --git a/rewrite/rewrite/python/style.py b/rewrite/rewrite/python/style.py index 1c470ffb..57129682 100644 --- a/rewrite/rewrite/python/style.py +++ b/rewrite/rewrite/python/style.py @@ -1,6 +1,5 @@ from __future__ import annotations -from abc import ABC from dataclasses import dataclass, replace from ..style import Style, NamedStyles @@ -450,6 +449,19 @@ def with_minimum(self, minimum: Minimum) -> BlankLinesStyle: return self if minimum is self._minimum else replace(self, _minimum=minimum) +@dataclass(frozen=True) +class GeneralFormatStyle(PythonStyle): + _use_crlf_new_lines: bool + + @property + def use_crlf_new_lines(self) -> bool: + return self._use_crlf_new_lines + + def with_use_crlf_new_lines(self, use_crlf_new_lines: bool) -> GeneralFormatStyle: + return self if use_crlf_new_lines is self._use_crlf_new_lines else replace(self, + _use_crlf_new_lines=use_crlf_new_lines) + + class IntelliJ(NamedStyles): @classmethod def spaces(cls) -> SpacesStyle: diff --git a/rewrite/tests/python/all/format/normalize_line_breaks_visitor_test.py b/rewrite/tests/python/all/format/normalize_line_breaks_visitor_test.py new file mode 100644 index 00000000..e99e3437 --- /dev/null +++ b/rewrite/tests/python/all/format/normalize_line_breaks_visitor_test.py @@ -0,0 +1,49 @@ +import unittest + +from rewrite.python import GeneralFormatStyle +from rewrite.python.format.normalize_line_breaks_visitor import NormalizeLineBreaksVisitor +from rewrite.test import from_visitor, RecipeSpec, rewrite_run, python + + +class TestNormalizeLineBreaksVisitor(unittest.TestCase): + + def setUp(self): + # language=python + self.windows = ( + "class Test:\r\n" + " # some comment\r\n" + " def test(self):\r\n" + " print()\r\n" + "\r\n" + ) + # language=python + self.linux = ( + "class Test:\n" + " # some comment\n" + " def test(self):\n" + " print()\n" + "\n" + ) + + @staticmethod + def normalize_line_breaks(use_crlf: bool) -> RecipeSpec: + style = GeneralFormatStyle(_use_crlf_new_lines=use_crlf) + return RecipeSpec().with_recipe(from_visitor(NormalizeLineBreaksVisitor(style))) + + def test_windows_to_linux(self): + rewrite_run( + python( + self.windows, + self.linux + ), + spec=self.normalize_line_breaks(use_crlf=False) + ) + + def test_linux_to_windows(self): + rewrite_run( + python( + self.linux, + self.windows + ), + spec=self.normalize_line_breaks(use_crlf=True) + )