diff --git a/CHANGELOG.md b/CHANGELOG.md index 065e5a57..75f5fa15 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ Changelog NOTE: isort follows the [semver](https://semver.org/) versioning standard. Find out more about isort's release policy [here](https://pycqa.github.io/isort/docs/major_releases/release_policy). +### 5.13.3 (Unreleased) + + - Fixed #2393: Calling isort with --sort-reexports with input from stdin fails due to non-seekable streams @jasur-py + ### 5.13.2 December 13 2023 - Apply the bracket fix from issue #471 only for use_parentheses=True (#2184) @bp72 diff --git a/isort/core.py b/isort/core.py index bff28458..90c9acd0 100644 --- a/isort/core.py +++ b/isort/core.py @@ -1,7 +1,8 @@ -import textwrap from io import StringIO +import textwrap from itertools import chain from typing import List, TextIO, Union +import sys import isort.literal from isort.settings import DEFAULT_CONFIG, Config @@ -52,6 +53,38 @@ def process( Returns `True` if there were changes that needed to be made (errors present) from what was provided in the input_stream, otherwise `False`. """ + # Check if output stream is seekable for reexport handling + output_seekable = False + # Explicitly treat sys.stdout and sys.stderr as non-seekable + if output_stream in (sys.stdout, sys.stderr): + output_seekable = False + elif hasattr(output_stream, 'seekable'): + try: + output_seekable = output_stream.seekable() + if output_seekable: + # Try a test seek to see if it actually works + pos = output_stream.tell() + try: + output_stream.seek(pos) + except Exception: + output_seekable = False + except Exception: + output_seekable = False + elif all(hasattr(output_stream, attr) for attr in ('seek', 'tell', 'truncate')): + try: + pos = output_stream.tell() + output_stream.seek(pos) + output_seekable = True + except Exception: + output_seekable = False + # Use internal buffer if output stream is not seekable and we might need reexport sorting + internal_output = None + if not output_seekable and config.sort_reexports: + internal_output = StringIO() + _output_stream = internal_output + else: + _output_stream = output_stream + line_separator: str = config.line_ending add_imports: List[str] = [format_natural(addition) for addition in config.add_imports] import_section: str = "" @@ -134,28 +167,44 @@ def process( if not line_separator: line_separator = "\n" - if code_sorting and code_sorting_section: - if is_reexport: - output_stream.seek(output_stream.tell() - reexport_rollback) + if code_sorting and code_sorting_section and is_reexport: + if output_seekable: + _output_stream.seek(_output_stream.tell() - reexport_rollback) reexport_rollback = 0 - sorted_code = textwrap.indent( - isort.literal.assignment( - code_sorting_section, - str(code_sorting), - extension, - config=_indented_config(config, indent), - ), - code_sorting_indent, - ) - made_changes = made_changes or _has_changed( - before=code_sorting_section, - after=sorted_code, - line_separator=line_separator, - ignore_whitespace=config.ignore_whitespace, - ) - output_stream.write(sorted_code) - if is_reexport: - output_stream.truncate() + else: + if not output_seekable and reexport_rollback > 0: + current_value = _output_stream.getvalue() + # Find the last occurrence of '__all__' and truncate to its index + idx = current_value.rfind('__all__') + if idx != -1: + # Truncate to the start of the line containing __all__ + line_start = current_value.rfind('\n', 0, idx) + if line_start == -1: + line_start = 0 + else: + line_start += 1 + _output_stream.seek(0) + _output_stream.truncate(0) + _output_stream.write(current_value[:line_start]) + reexport_rollback = 0 + sorted_code = textwrap.indent( + isort.literal.assignment( + code_sorting_section, + str(code_sorting), + extension, + config=_indented_config(config, indent), + ), + code_sorting_indent, + ) + made_changes = made_changes or _has_changed( + before=code_sorting_section, + after=sorted_code, + line_separator=line_separator, + ignore_whitespace=config.ignore_whitespace, + ) + _output_stream.write(sorted_code) + if is_reexport: + _output_stream.truncate() else: stripped_line = line.strip() if stripped_line and not line_separator: @@ -239,6 +288,21 @@ def process( code_sorting_indent = line[: -len(line.lstrip())] not_imports = True code_sorting_section += line + if is_reexport and not output_seekable and reexport_rollback > 0: + current_value = _output_stream.getvalue() + # Find the last occurrence of '__all__' and truncate to its index + idx = current_value.rfind('__all__') + if idx != -1: + # Truncate to the start of the line containing __all__ + line_start = current_value.rfind('\n', 0, idx) + if line_start == -1: + line_start = 0 + else: + line_start += 1 + _output_stream.seek(0) + _output_stream.truncate(0) + _output_stream.write(current_value[:line_start]) + reexport_rollback = 0 reexport_rollback = len(line) is_reexport = True elif code_sorting: @@ -259,11 +323,29 @@ def process( ignore_whitespace=config.ignore_whitespace, ) if is_reexport: - output_stream.seek(output_stream.tell() - reexport_rollback) - reexport_rollback = 0 - output_stream.write(sorted_code) + if output_seekable: + _output_stream.seek(_output_stream.tell() - reexport_rollback) + reexport_rollback = 0 + else: + if not output_seekable and reexport_rollback > 0: + current_value = _output_stream.getvalue() + # Find the last occurrence of '__all__' + # and truncate to its index + idx = current_value.rfind('__all__') + if idx != -1: + # Truncate to the start of the line containing __all__ + line_start = current_value.rfind('\n', 0, idx) + if line_start == -1: + line_start = 0 + else: + line_start += 1 + _output_stream.seek(0) + _output_stream.truncate(0) + _output_stream.write(current_value[:line_start]) + reexport_rollback = 0 + _output_stream.write(sorted_code) if is_reexport: - output_stream.truncate() + _output_stream.truncate() not_imports = True code_sorting = False code_sorting_section = "" @@ -277,7 +359,7 @@ def process( or stripped_line in config.section_comments_end ): if import_section and not contains_imports: - output_stream.write(import_section) + _output_stream.write(import_section) import_section = line not_imports = False else: @@ -367,7 +449,7 @@ def process( lines_before += line continue if not import_section: - output_stream.write("".join(lines_before)) + _output_stream.write("".join(lines_before)) lines_before = [] raw_import_section: str = import_section @@ -384,7 +466,7 @@ def process( add_line_separator = line_separator or "\n" import_section = add_line_separator.join(add_imports) + add_line_separator if end_of_file and index != 0: - output_stream.write(add_line_separator) + _output_stream.write(add_line_separator) contains_imports = True add_imports = [] @@ -404,7 +486,7 @@ def process( import_section += line raw_import_section += line if not contains_imports: - output_stream.write(import_section) + _output_stream.write(import_section) else: leading_whitespace = import_section[: -len(import_section.lstrip())] @@ -444,12 +526,12 @@ def process( line_separator=line_separator, ignore_whitespace=config.ignore_whitespace, ) - output_stream.write(sorted_import_section) + _output_stream.write(sorted_import_section) if not line and not indent and next_import_section: - output_stream.write(line_separator) + _output_stream.write(line_separator) if indent: - output_stream.write(line) + _output_stream.write(line) if not next_import_section: indent = "" @@ -461,7 +543,7 @@ def process( import_section = next_import_section next_import_section = "" else: - output_stream.write(line) + _output_stream.write(line) not_imports = False if stripped_line and not in_quote and not import_section and not next_import_section: @@ -471,7 +553,7 @@ def process( if not new_line: break - output_stream.write(new_line) + _output_stream.write(new_line) stripped_line = new_line.strip().split("#")[0] if stripped_line.startswith(("raise", "yield")): @@ -480,13 +562,18 @@ def process( if not new_line: break - output_stream.write(new_line) + _output_stream.write(new_line) stripped_line = new_line.strip().split("#")[0] if made_changes and config.only_modified: for output_str in verbose_output: print(output_str) + # Write internal buffer to actual output stream if we used one + if internal_output is not None: + internal_output.seek(0) + output_stream.write(internal_output.read()) + return made_changes diff --git a/tests/unit/test_isort.py b/tests/unit/test_isort.py index e72f1fee..7eb93f16 100644 --- a/tests/unit/test_isort.py +++ b/tests/unit/test_isort.py @@ -5741,3 +5741,86 @@ def test_reexport_multiline_long_rollback() -> None: test """ assert isort.code(test_input, config=Config(sort_reexports=True)) == expd_output + + +def test_reexport_non_seekable_stream() -> None: + """Test that reexport sorting works with non-seekable streams like stdout""" + from io import StringIO + + test_input = """from test import B, A +__all__ = ["B", "A"]""" + + expected_output = """from test import A, B + +__all__ = ['A', 'B']""" + + # Test with a non-seekable stream (simulating stdout) + input_stream = StringIO(test_input) + output_stream = StringIO() + + # Mock sys.stdout to be non-seekable + original_stdout = sys.stdout + try: + sys.stdout = output_stream + api.sort_stream( + input_stream=input_stream, + output_stream=output_stream, + config=Config(sort_reexports=True), + ) + output_stream.seek(0) + result = output_stream.read() + assert result == expected_output + finally: + sys.stdout = original_stdout + +def test_reexport_non_seekable_stream() -> None: + """Test that reexport sorting works with non-seekable streams like stdout""" + from io import StringIO + + test_input = """from test import B, A +__all__ = ["B", "A"]""" + + expected_output = """from test import A, B + +__all__ = ['A', 'B']""" + + # Test with a non-seekable stream (simulating stdout) + input_stream = StringIO(test_input) + + # Create a non-seekable output stream that allows reading the result + class NonSeekableStream(StringIO): + def __init__(self): + super().__init__() + self._allow_seek = False + + def seek(self, *args, **kwargs): + if not self._allow_seek: + raise OSError("Stream is not seekable") + return super().seek(*args, **kwargs) + + def tell(self, *args, **kwargs): + if not self._allow_seek: + raise OSError("Stream is not seekable") + return super().tell(*args, **kwargs) + + def truncate(self, *args, **kwargs): + if not self._allow_seek: + raise OSError("Stream is not seekable") + return super().truncate(*args, **kwargs) + + def allow_seek(self): + self._allow_seek = True + + output_stream = NonSeekableStream() + + api.sort_stream( + input_stream=input_stream, + output_stream=output_stream, + config=Config(sort_reexports=True), + ) + + # Allow seeking to read the result + output_stream.allow_seek() + output_stream.seek(0) + result = output_stream.read() + assert result == expected_output