diff --git a/stdlib/@tests/test_cases/check_io.py b/stdlib/@tests/test_cases/check_io.py index ce8c34aedbad..c3ee5aa4c6c0 100644 --- a/stdlib/@tests/test_cases/check_io.py +++ b/stdlib/@tests/test_cases/check_io.py @@ -1,9 +1,17 @@ -from _io import BufferedReader +from _io import BufferedReader, BufferedRWPair, BufferedWriter from gzip import GzipFile from io import FileIO, RawIOBase, TextIOWrapper +from socket import SocketIO +from typing import Any from typing_extensions import assert_type +socket: Any = None + BufferedReader(RawIOBase()) +BufferedWriter(RawIOBase()) +BufferedWriter(SocketIO(socket, "r")) + +BufferedRWPair(open("", "rb"), open("", "wb")) assert_type(TextIOWrapper(FileIO("")).buffer, FileIO) assert_type(TextIOWrapper(FileIO(13)).detach(), FileIO) diff --git a/stdlib/_io.pyi b/stdlib/_io.pyi index 2d2a60e4dddf..7a8b49fa6203 100644 --- a/stdlib/_io.pyi +++ b/stdlib/_io.pyi @@ -163,13 +163,39 @@ class BufferedReader(BufferedIOBase, _BufferedIOBase, BinaryIO, Generic[_Buffere def seek(self, target: int, whence: int = 0, /) -> int: ... def truncate(self, pos: int | None = None, /) -> int: ... +@type_check_only +class _BufferedWriterStream(Protocol): + def write(self, b: WriteableBuffer, /) -> int | None: ... + def seek(self, pos: int, whence: int, /) -> int: ... + def tell(self) -> int: ... + def truncate(self, size: int, /) -> int: ... + def flush(self) -> object: ... + def close(self) -> object: ... + @property + def closed(self) -> bool: ... + def writable(self) -> bool: ... + def seekable(self) -> bool: ... + + # The following methods just pass through to the underlying stream. Since + # not all streams support them, they are marked as optional here, and will + # raise an AttributeError if called on a stream that does not support them. + + # @property + # def name(self) -> Any: ... # Type is inconsistent between the various I/O types. + # @property + # def mode(self) -> str: ... + # def fileno(self) -> int: ... + # def isatty(self) -> bool: ... + +_BufferedWriterStreamT = TypeVar("_BufferedWriterStreamT", bound=_BufferedWriterStream, default=_BufferedWriterStream) + @disjoint_base -class BufferedWriter(BufferedIOBase, _BufferedIOBase, BinaryIO): # type: ignore[misc] # incompatible definitions of writelines in the base classes - raw: RawIOBase +class BufferedWriter(BufferedIOBase, _BufferedIOBase, BinaryIO, Generic[_BufferedWriterStreamT]): # type: ignore[misc] # incompatible definitions of writelines in the base classes + raw: _BufferedWriterStreamT if sys.version_info >= (3, 14): - def __init__(self, raw: RawIOBase, buffer_size: int = 131072) -> None: ... + def __init__(self, raw: _BufferedWriterStreamT, buffer_size: int = 131072) -> None: ... else: - def __init__(self, raw: RawIOBase, buffer_size: int = 8192) -> None: ... + def __init__(self, raw: _BufferedWriterStreamT, buffer_size: int = 8192) -> None: ... def write(self, buffer: ReadableBuffer, /) -> int: ... def seek(self, target: int, whence: int = 0, /) -> int: ... @@ -190,11 +216,15 @@ class BufferedRandom(BufferedIOBase, _BufferedIOBase, BinaryIO): # type: ignore def truncate(self, pos: int | None = None, /) -> int: ... @disjoint_base -class BufferedRWPair(BufferedIOBase, _BufferedIOBase, Generic[_BufferedReaderStreamT]): +class BufferedRWPair(BufferedIOBase, _BufferedIOBase, Generic[_BufferedReaderStreamT, _BufferedWriterStreamT]): if sys.version_info >= (3, 14): - def __init__(self, reader: _BufferedReaderStreamT, writer: RawIOBase, buffer_size: int = 131072, /) -> None: ... + def __init__( + self, reader: _BufferedReaderStreamT, writer: _BufferedWriterStreamT, buffer_size: int = 131072, / + ) -> None: ... else: - def __init__(self, reader: _BufferedReaderStreamT, writer: RawIOBase, buffer_size: int = 8192, /) -> None: ... + def __init__( + self, reader: _BufferedReaderStreamT, writer: _BufferedWriterStreamT, buffer_size: int = 8192, / + ) -> None: ... def peek(self, size: int = 0, /) -> bytes: ...