Skip to content
101 changes: 60 additions & 41 deletions gridfs/asynchronous/grid_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,24 +1176,6 @@ def __getattr__(self, name: str) -> Any:
raise AttributeError("GridIn object has no attribute '%s'" % name)

def __setattr__(self, name: str, value: Any) -> None:
# For properties of this instance like _buffer, or descriptors set on
# the class like filename, use regular __setattr__
if name in self.__dict__ or name in self.__class__.__dict__:
object.__setattr__(self, name, value)
else:
if _IS_SYNC:
# All other attributes are part of the document in db.fs.files.
# Store them to be sent to server on close() or if closed, send
# them now.
self._file[name] = value
if self._closed:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
else:
raise AttributeError(
"AsyncGridIn does not support __setattr__. Use AsyncGridIn.set() instead"
)

async def set(self, name: str, value: Any) -> None:
# For properties of this instance like _buffer, or descriptors set on
# the class like filename, use regular __setattr__
if name in self.__dict__ or name in self.__class__.__dict__:
Expand All @@ -1204,9 +1186,17 @@ async def set(self, name: str, value: Any) -> None:
# them now.
self._file[name] = value
if self._closed:
await self._coll.files.update_one(
{"_id": self._file["_id"]}, {"$set": {name: value}}
)
if _IS_SYNC:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
else:
raise AttributeError(
"AsyncGridIn does not support __setattr__ after being closed(). Set the attribute before closing the file or use AsyncGridIn.set() instead"
)

async def set(self, name: str, value: Any) -> None:
self._file[name] = value
if self._closed:
await self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})

async def _flush_data(self, data: Any, force: bool = False) -> None:
"""Flush `data` to a chunk."""
Expand Down Expand Up @@ -1400,7 +1390,11 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Any:
return False


class AsyncGridOut(io.IOBase):
GRIDOUT_BASE_CLASS = io.IOBase if _IS_SYNC else object # type: Any


class AsyncGridOut(GRIDOUT_BASE_CLASS): # type: ignore

"""Class to read data out of GridFS."""

def __init__(
Expand Down Expand Up @@ -1460,6 +1454,8 @@ def __init__(
self._position = 0
self._file = file_document
self._session = session
if not _IS_SYNC:
self.closed = False

_id: Any = _a_grid_out_property("_id", "The ``'_id'`` value for this file.")
filename: str = _a_grid_out_property("filename", "Name of this file.")
Expand All @@ -1486,16 +1482,43 @@ def __init__(
_file: Any
_chunk_iter: Any

async def __anext__(self) -> bytes:
return super().__next__()
if not _IS_SYNC:
closed: bool

def __next__(self) -> bytes: # noqa: F811, RUF100
if _IS_SYNC:
return super().__next__()
else:
raise TypeError(
"AsyncGridOut does not support synchronous iteration. Use `async for` instead"
)
async def __anext__(self) -> bytes:
line = await self.readline()
if line:
return line
raise StopAsyncIteration()

async def to_list(self) -> list[bytes]:
return [x async for x in self] # noqa: C416, RUF100

async def readline(self, size: int = -1) -> bytes:
"""Read one line or up to `size` bytes from the file.

:param size: the maximum number of bytes to read
"""
return await self._read_size_or_line(size=size, line=True)

async def readlines(self, size: int = -1) -> list[bytes]:
"""Read one line or up to `size` bytes from the file.

:param size: the maximum number of bytes to read
"""
await self.open()
lines = []
remainder = int(self.length) - self._position
bytes_read = 0
while remainder > 0:
line = await self._read_size_or_line(line=True)
bytes_read += len(line)
lines.append(line)
remainder = int(self.length) - self._position
if 0 < size < bytes_read:
break

return lines

async def open(self) -> None:
if not self._file:
Expand Down Expand Up @@ -1616,18 +1639,11 @@ async def read(self, size: int = -1) -> bytes:
"""
return await self._read_size_or_line(size=size)

async def readline(self, size: int = -1) -> bytes: # type: ignore[override]
"""Read one line or up to `size` bytes from the file.

:param size: the maximum number of bytes to read
"""
return await self._read_size_or_line(size=size, line=True)

def tell(self) -> int:
"""Return the current position of this file."""
return self._position

async def seek(self, pos: int, whence: int = _SEEK_SET) -> int: # type: ignore[override]
async def seek(self, pos: int, whence: int = _SEEK_SET) -> int:
"""Set the current position of this file.

:param pos: the position (or offset if using relative
Expand Down Expand Up @@ -1690,12 +1706,15 @@ def __aiter__(self) -> AsyncGridOut:
"""
return self

async def close(self) -> None: # type: ignore[override]
async def close(self) -> None:
"""Make GridOut more generically file-like."""
if self._chunk_iter:
await self._chunk_iter.close()
self._chunk_iter = None
super().close()
if _IS_SYNC:
super().close()
else:
self.closed = True

def write(self, value: Any) -> NoReturn:
raise io.UnsupportedOperation("write")
Expand Down
97 changes: 60 additions & 37 deletions gridfs/synchronous/grid_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,24 +1166,6 @@ def __getattr__(self, name: str) -> Any:
raise AttributeError("GridIn object has no attribute '%s'" % name)

def __setattr__(self, name: str, value: Any) -> None:
# For properties of this instance like _buffer, or descriptors set on
# the class like filename, use regular __setattr__
if name in self.__dict__ or name in self.__class__.__dict__:
object.__setattr__(self, name, value)
else:
if _IS_SYNC:
# All other attributes are part of the document in db.fs.files.
# Store them to be sent to server on close() or if closed, send
# them now.
self._file[name] = value
if self._closed:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
else:
raise AttributeError(
"GridIn does not support __setattr__. Use GridIn.set() instead"
)

def set(self, name: str, value: Any) -> None:
# For properties of this instance like _buffer, or descriptors set on
# the class like filename, use regular __setattr__
if name in self.__dict__ or name in self.__class__.__dict__:
Expand All @@ -1194,7 +1176,17 @@ def set(self, name: str, value: Any) -> None:
# them now.
self._file[name] = value
if self._closed:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
if _IS_SYNC:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
else:
raise AttributeError(
"GridIn does not support __setattr__ after being closed(). Set the attribute before closing the file or use GridIn.set() instead"
)

def set(self, name: str, value: Any) -> None:
self._file[name] = value
if self._closed:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})

def _flush_data(self, data: Any, force: bool = False) -> None:
"""Flush `data` to a chunk."""
Expand Down Expand Up @@ -1388,7 +1380,11 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Any:
return False


class GridOut(io.IOBase):
GRIDOUT_BASE_CLASS = io.IOBase if _IS_SYNC else object # type: Any


class GridOut(GRIDOUT_BASE_CLASS): # type: ignore

"""Class to read data out of GridFS."""

def __init__(
Expand Down Expand Up @@ -1448,6 +1444,8 @@ def __init__(
self._position = 0
self._file = file_document
self._session = session
if not _IS_SYNC:
self.closed = False

_id: Any = _grid_out_property("_id", "The ``'_id'`` value for this file.")
filename: str = _grid_out_property("filename", "Name of this file.")
Expand All @@ -1474,14 +1472,43 @@ def __init__(
_file: Any
_chunk_iter: Any

def __next__(self) -> bytes:
return super().__next__()
if not _IS_SYNC:
closed: bool

def __next__(self) -> bytes: # noqa: F811, RUF100
if _IS_SYNC:
return super().__next__()
else:
raise TypeError("GridOut does not support synchronous iteration. Use `for` instead")
def __next__(self) -> bytes:
line = self.readline()
if line:
return line
raise StopIteration()

def to_list(self) -> list[bytes]:
return [x for x in self] # noqa: C416, RUF100

def readline(self, size: int = -1) -> bytes:
"""Read one line or up to `size` bytes from the file.

:param size: the maximum number of bytes to read
"""
return self._read_size_or_line(size=size, line=True)

def readlines(self, size: int = -1) -> list[bytes]:
"""Read one line or up to `size` bytes from the file.

:param size: the maximum number of bytes to read
"""
self.open()
lines = []
remainder = int(self.length) - self._position
bytes_read = 0
while remainder > 0:
line = self._read_size_or_line(line=True)
bytes_read += len(line)
lines.append(line)
remainder = int(self.length) - self._position
if 0 < size < bytes_read:
break

return lines

def open(self) -> None:
if not self._file:
Expand Down Expand Up @@ -1602,18 +1629,11 @@ def read(self, size: int = -1) -> bytes:
"""
return self._read_size_or_line(size=size)

def readline(self, size: int = -1) -> bytes: # type: ignore[override]
"""Read one line or up to `size` bytes from the file.

:param size: the maximum number of bytes to read
"""
return self._read_size_or_line(size=size, line=True)

def tell(self) -> int:
"""Return the current position of this file."""
return self._position

def seek(self, pos: int, whence: int = _SEEK_SET) -> int: # type: ignore[override]
def seek(self, pos: int, whence: int = _SEEK_SET) -> int:
"""Set the current position of this file.

:param pos: the position (or offset if using relative
Expand Down Expand Up @@ -1676,12 +1696,15 @@ def __iter__(self) -> GridOut:
"""
return self

def close(self) -> None: # type: ignore[override]
def close(self) -> None:
"""Make GridOut more generically file-like."""
if self._chunk_iter:
self._chunk_iter.close()
self._chunk_iter = None
super().close()
if _IS_SYNC:
super().close()
else:
self.closed = True

def write(self, value: Any) -> NoReturn:
raise io.UnsupportedOperation("write")
Expand Down
5 changes: 5 additions & 0 deletions pymongo/asynchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,13 @@ async def inner(*args: Any, **kwargs: Any) -> Any:

if sys.version_info >= (3, 10):
anext = builtins.anext
aiter = builtins.aiter
else:

async def anext(cls: Any) -> Any:
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#anext."""
return await cls.__anext__()

def aiter(cls: Any) -> Any:
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#anext."""
return cls.__aiter__()
2 changes: 1 addition & 1 deletion pymongo/asynchronous/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ async def _process_change(
if server:
await server.pool.reset(interrupt_connections=interrupt_connections)

# Wake waiters in select_servers().
# Wake anything waiting in select_servers().
self._condition.notify_all()

async def on_change(
Expand Down
5 changes: 5 additions & 0 deletions pymongo/synchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,13 @@ def inner(*args: Any, **kwargs: Any) -> Any:

if sys.version_info >= (3, 10):
next = builtins.next
iter = builtins.iter
else:

def next(cls: Any) -> Any:
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#next."""
return cls.__next__()

def iter(cls: Any) -> Any:
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#next."""
return cls.__iter__()
2 changes: 1 addition & 1 deletion pymongo/synchronous/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def _process_change(
if server:
server.pool.reset(interrupt_connections=interrupt_connections)

# Wake waiters in select_servers().
# Wake anything waiting in select_servers().
self._condition.notify_all()

def on_change(
Expand Down
4 changes: 2 additions & 2 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,11 +947,11 @@ def tearDownClass(cls):

@classmethod
def _setup_class(cls):
cls._setup_class()
pass

@classmethod
def _tearDown_class(cls):
cls._tearDown_class()
pass


class IntegrationTest(PyMongoTestCase):
Expand Down
4 changes: 2 additions & 2 deletions test/asynchronous/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,11 +949,11 @@ def tearDownClass(cls):

@classmethod
async def _setup_class(cls):
await cls._setup_class()
pass

@classmethod
async def _tearDown_class(cls):
await cls._tearDown_class()
pass


class AsyncIntegrationTest(AsyncPyMongoTestCase):
Expand Down
Loading
Loading