Skip to content

Commit 481600b

Browse files
terencehonles authored and ShaneHarvey committed
PYTHON-1695 GridOut/GridIn more closely implement io.IOBase (#387)
Allows GridOut to be wrapped with zipfile.ZipFile from the stdlib.
1 parent 5950abf commit 481600b

File tree

3 files changed

+47
-0
lines changed

3 files changed

+47
-0
lines changed

doc/contributors.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,4 @@ The following is a list of people who have contributed to
8585
- Jagrut Trivedi(Jagrut)
8686
- Shrey Batra(shreybatra)
8787
- Felipe Rodrigues(fbidu)
88+
- Terence Honles (terencehonles)

gridfs/grid_file.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Tools for representing files stored in GridFS."""
1616
import datetime
1717
import hashlib
18+
import io
1819
import math
1920
import os
2021

@@ -312,6 +313,15 @@ def close(self):
312313
self.__flush()
313314
object.__setattr__(self, "_closed", True)
314315

316+
def read(self, size=-1):
    """Reading is not supported on this file object.

    :Parameters:
      - `size` (optional): accepted for signature compatibility with
        file-like ``read`` and otherwise ignored.

    Always raises :class:`io.UnsupportedOperation`.
    """
    raise io.UnsupportedOperation('read')
318+
319+
def readable(self):
    """Return ``False``: this file object cannot be read from."""
    return False
321+
322+
def seekable(self):
    """Return ``False``: this file object does not support seeking."""
    return False
324+
315325
def write(self, data):
316326
"""Write data to the file. There is no return value.
317327
@@ -379,6 +389,9 @@ def writelines(self, sequence):
379389
for line in sequence:
380390
self.write(line)
381391

392+
def writable(self):
    """Return ``True``: this file object accepts writes.

    Spelled as in :meth:`io.IOBase.writable` so that code probing the
    standard stream capability methods finds it.
    """
    return True

def writeable(self):
    """Misspelled alias of :meth:`writable`.

    Kept so that existing callers of the original name keep working.
    """
    return True
394+
382395
def __enter__(self):
383396
"""Support for the context manager protocol.
384397
"""
@@ -472,6 +485,9 @@ def __getattr__(self, name):
472485
return self._file[name]
473486
raise AttributeError("GridOut object has no attribute '%s'" % name)
474487

488+
def readable(self):
    """Return ``True``: this file object can be read from."""
    return True
490+
475491
def readchunk(self):
476492
"""Reads a chunk at a time. If the current position is within a
477493
chunk the remainder of the chunk is returned.
@@ -617,6 +633,9 @@ def seek(self, pos, whence=_SEEK_SET):
617633
self.__chunk_iter.close()
618634
self.__chunk_iter = None
619635

636+
def seekable(self):
    """Return ``True``: this file object supports :meth:`seek`."""
    return True
638+
620639
def __iter__(self):
621640
"""Return an iterator over all of this file's data.
622641
@@ -625,6 +644,11 @@ def __iter__(self):
625644
useful when serving files using a webserver that handles
626645
such an iterator efficiently.
627646
647+
.. note::
648+
This is different from :py:class:`io.IOBase` which iterates over
649+
*lines* in the file. Use :meth:`GridOut.readline` to read line by
650+
line instead of chunk by chunk.
651+
628652
.. versionchanged:: 3.8
629653
The iterator now raises :class:`CorruptGridFile` when encountering
630654
any truncated, missing, or extra chunk in a file. The previous
@@ -639,6 +663,9 @@ def close(self):
639663
self.__chunk_iter.close()
640664
self.__chunk_iter = None
641665

666+
def write(self, value):
    """Writing is not supported on this file object.

    :Parameters:
      - `value`: accepted for signature compatibility with file-like
        ``write`` and otherwise ignored.

    Always raises :class:`io.UnsupportedOperation`.
    """
    raise io.UnsupportedOperation('write')

def writable(self):
    """Return ``False``: this file object is read-only.

    Provided alongside :meth:`readable` and :meth:`seekable` so that
    callers probing stream capabilities get an answer instead of an
    :exc:`AttributeError` for the missing name.
    """
    return False
668+
642669
def __enter__(self):
643670
"""Makes it possible to use :class:`GridOut` files
644671
with the context manager protocol.

test/test_grid_file.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import datetime
2121
import sys
22+
import zipfile
2223
sys.path[0:0] = [""]
2324

2425
from bson.objectid import ObjectId
@@ -644,6 +645,24 @@ def test_survive_cursor_not_found(self):
644645
# Paranoid, ensure that a getMore was actually sent.
645646
self.assertIn("getMore", listener.started_command_names())
646647

648+
def test_zip(self):
    """Store a zip archive in GridFS and read it back through zipfile."""
    # Build a single-member archive in an in-memory buffer.
    archive = StringIO()
    with zipfile.ZipFile(archive, "w") as writer:
        writer.writestr("test.txt", b"hello world")
    archive.seek(0)

    # Upload the archive; it should produce one files document and one chunk.
    grid_in = GridIn(self.db.fs, filename="test.zip")
    grid_in.write(archive)
    grid_in.close()
    self.assertEqual(1, self.db.fs.files.count_documents({}))
    self.assertEqual(1, self.db.fs.chunks.count_documents({}))

    # Hand the stored file directly to ZipFile and check its contents.
    grid_out = GridOut(self.db.fs, grid_in._id)
    reader = zipfile.ZipFile(grid_out)
    self.assertSequenceEqual(reader.namelist(), ["test.txt"])
    self.assertEqual(reader.read("test.txt"), b"hello world")
665+
647666

648667
if __name__ == "__main__":
649668
unittest.main()

0 commit comments

Comments
 (0)