Skip to content
43 changes: 43 additions & 0 deletions Lib/test/test_io/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,49 @@ def test_RawIOBase_read(self):
self.assertEqual(rawio.read(2), None)
self.assertEqual(rawio.read(2), b"")

def test_exact_RawIOBase(self):
rawio = self.MockRawIOWithoutRead((b"ab", b"cd"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I wasn't precise enough. RawIOBase.read() is implemented in terms of readinto. So, what we want to do is provide a readinto method that does something bad as it's given the internal buffer we constructed in C. I don't know if it's possible to cause a segfault on main with that approach though.

Stated otherwise, we want some class:

class EvilReadInto(MockRawIOWithoutRead):
    def readinto(self, buf):
        # do something bad with 'buf'

And then

r = EvilReadInto(something_here)
r.read()  # on main, this call should crash, but not with this patch

buf = bytearray(2)
n = rawio.readinto(buf)
self.assertEqual(n, 2)
self.assertEqual(buf, b"ab")

n = rawio.readinto(buf)
self.assertEqual(n, 2)
self.assertEqual(buf, b"cd")

n = rawio.readinto(buf)
self.assertEqual(n, 0)
self.assertEqual(buf, b"cd")

def test_partial_readinto_write_RawIOBase(self):
rawio = self.MockRawIOWithoutRead((b"abcdef",))
buf = bytearray(3)
n = rawio.readinto(buf)
self.assertEqual(n, 3)
self.assertEqual(buf, b"abc")

n2 = rawio.readinto(buf)
self.assertEqual(n2, 3)
self.assertEqual(buf, b"def")

def test_readinto_none_RawIOBase(self):
rawio = self.MockRawIOWithoutRead((None, b"x"))
buf = bytearray(2)
n = rawio.readinto(buf)
self.assertIsNone(n)

n2 = rawio.readinto(buf)
self.assertEqual(n2, 1)
self.assertEqual(buf[0], ord('x'))

def test_read_default_via_readinto_RawIOBase(self):
rawio = self.MockRawIOWithoutRead((b"abcdef",))
result = rawio.read(4)
self.assertEqual(result, b"abcd")
result2 = rawio.read(4)
self.assertEqual(result2, b"ef")

def test_types_have_dict(self):
test = (
self.IOBase(),
Expand Down
Loading