
Commit 3dafd4c

Fix IOBasePayload reading entire files into memory instead of chunking (#11139)
1 parent 6cb2244 commit 3dafd4c

3 files changed: +164 -3 lines changed


CHANGES/11138.bugfix.rst

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+Fixed ``IOBasePayload`` and ``TextIOPayload`` reading entire files into memory when streaming large files -- by :user:`bdraco`.
+
+When using file-like objects with the aiohttp client, the entire file would be read into memory if the file size was provided in the ``Content-Length`` header. This could cause out-of-memory errors when uploading large files. The payload classes now correctly read data in chunks of ``READ_SIZE`` (64KB) regardless of the total content length.
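
For illustration, the core of the fix is the extra ``READ_SIZE`` bound inside the ``min()`` call. The following is a minimal standalone sketch of that clamp; ``next_read_size`` is a hypothetical helper written for this note (not a function in aiohttp), and it assumes ``READ_SIZE`` is the 64KB value the changelog refers to:

from typing import Optional

READ_SIZE = 2**16  # 64KB, assumed to match aiohttp.payload.READ_SIZE

def next_read_size(file_size: Optional[int], remaining_content_len: Optional[int]) -> int:
    # Before the fix the clamp was min(file_size or READ_SIZE, remaining or READ_SIZE),
    # so a known large file size or Content-Length produced one huge read().
    # Adding READ_SIZE as a third argument bounds every read at 64KB.
    return min(READ_SIZE, file_size or READ_SIZE, remaining_content_len or READ_SIZE)

# A 10MB file streamed with a 10MB Content-Length is now read 65536 bytes at a time.
assert next_read_size(10 * 1024 * 1024, 10 * 1024 * 1024) == READ_SIZE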

aiohttp/payload.py

Lines changed: 11 additions & 3 deletions
@@ -512,7 +512,7 @@ def _read_and_available_len(
         self._set_or_restore_start_position()
         size = self.size  # Call size only once since it does I/O
         return size, self._value.read(
-            min(size or READ_SIZE, remaining_content_len or READ_SIZE)
+            min(READ_SIZE, size or READ_SIZE, remaining_content_len or READ_SIZE)
         )
 
     def _read(self, remaining_content_len: Optional[int]) -> bytes:
@@ -615,7 +615,15 @@ async def write_with_length(
                 return
 
             # Read next chunk
-            chunk = await loop.run_in_executor(None, self._read, remaining_content_len)
+            chunk = await loop.run_in_executor(
+                None,
+                self._read,
+                (
+                    min(READ_SIZE, remaining_content_len)
+                    if remaining_content_len is not None
+                    else READ_SIZE
+                ),
+            )
 
     def _should_stop_writing(
         self,
@@ -757,7 +765,7 @@ def _read_and_available_len(
         self._set_or_restore_start_position()
         size = self.size
         chunk = self._value.read(
-            min(size or READ_SIZE, remaining_content_len or READ_SIZE)
+            min(READ_SIZE, size or READ_SIZE, remaining_content_len or READ_SIZE)
         )
         return size, chunk.encode(self._encoding) if self._encoding else chunk.encode()
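
For context, this is the typical client pattern the bug affected. The snippet below is a sketch only -- the upload URL and file path are placeholders, not values taken from the commit:

import asyncio

import aiohttp


async def upload(path: str) -> None:
    async with aiohttp.ClientSession() as session:
        # Passing an open binary file as ``data`` wraps it in an IOBasePayload.
        # With this change the file is read in READ_SIZE chunks while the body is
        # written, even when the request has a known Content-Length.
        with open(path, "rb") as f:
            async with session.post("https://example.com/upload", data=f) as resp:
                resp.raise_for_status()


asyncio.run(upload("large_file.bin"))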

tests/test_payload.py

Lines changed: 150 additions & 0 deletions
@@ -12,6 +12,7 @@
 
 from aiohttp import payload
 from aiohttp.abc import AbstractStreamWriter
+from aiohttp.payload import READ_SIZE
 
 
 class BufferWriter(AbstractStreamWriter):
@@ -363,6 +364,155 @@ async def test_iobase_payload_exact_chunk_size_limit() -> None:
     assert written == data[:chunk_size]
 
 
+async def test_iobase_payload_reads_in_chunks() -> None:
+    """Test IOBasePayload reads data in chunks of READ_SIZE, not all at once."""
+    # Create a large file that's multiple times larger than READ_SIZE
+    large_data = b"x" * (READ_SIZE * 3 + 1000)  # ~192KB + 1000 bytes
+
+    # Mock the file-like object to track read calls
+    mock_file = unittest.mock.Mock(spec=io.BytesIO)
+    mock_file.tell.return_value = 0
+    mock_file.fileno.side_effect = AttributeError  # Make size return None
+
+    # Track the sizes of read() calls
+    read_sizes = []
+
+    def mock_read(size: int) -> bytes:
+        read_sizes.append(size)
+        # Return data based on how many times read was called
+        call_count = len(read_sizes)
+        if call_count == 1:
+            return large_data[:size]
+        elif call_count == 2:
+            return large_data[READ_SIZE : READ_SIZE + size]
+        elif call_count == 3:
+            return large_data[READ_SIZE * 2 : READ_SIZE * 2 + size]
+        else:
+            return large_data[READ_SIZE * 3 :]
+
+    mock_file.read.side_effect = mock_read
+
+    payload_obj = payload.IOBasePayload(mock_file)
+    writer = MockStreamWriter()
+
+    # Write with a large content_length
+    await payload_obj.write_with_length(writer, len(large_data))
+
+    # Verify that reads were limited to READ_SIZE
+    assert len(read_sizes) > 1  # Should have multiple reads
+    for read_size in read_sizes:
+        assert (
+            read_size <= READ_SIZE
+        ), f"Read size {read_size} exceeds READ_SIZE {READ_SIZE}"
+
+
+async def test_iobase_payload_large_content_length() -> None:
+    """Test IOBasePayload with very large content_length doesn't read all at once."""
+    data = b"x" * (READ_SIZE + 1000)
+
+    # Create a custom file-like object that tracks read sizes
+    class TrackingBytesIO(io.BytesIO):
+        def __init__(self, data: bytes) -> None:
+            super().__init__(data)
+            self.read_sizes: List[int] = []
+
+        def read(self, size: Optional[int] = -1) -> bytes:
+            self.read_sizes.append(size if size is not None else -1)
+            return super().read(size)
+
+    tracking_file = TrackingBytesIO(data)
+    payload_obj = payload.IOBasePayload(tracking_file)
+    writer = MockStreamWriter()
+
+    # Write with a very large content_length (simulating the bug scenario)
+    large_content_length = 10 * 1024 * 1024  # 10MB
+    await payload_obj.write_with_length(writer, large_content_length)
+
+    # Verify no single read exceeded READ_SIZE
+    for read_size in tracking_file.read_sizes:
+        assert (
+            read_size <= READ_SIZE
+        ), f"Read size {read_size} exceeds READ_SIZE {READ_SIZE}"
+
+    # Verify the correct amount of data was written
+    assert writer.get_written_bytes() == data
+
+
+async def test_textio_payload_reads_in_chunks() -> None:
+    """Test TextIOPayload reads data in chunks of READ_SIZE, not all at once."""
+    # Create a large text file that's multiple times larger than READ_SIZE
+    large_text = "x" * (READ_SIZE * 3 + 1000)  # ~192KB + 1000 chars
+
+    # Mock the file-like object to track read calls
+    mock_file = unittest.mock.Mock(spec=io.StringIO)
+    mock_file.tell.return_value = 0
+    mock_file.fileno.side_effect = AttributeError  # Make size return None
+    mock_file.encoding = "utf-8"
+
+    # Track the sizes of read() calls
+    read_sizes = []
+
+    def mock_read(size: int) -> str:
+        read_sizes.append(size)
+        # Return data based on how many times read was called
+        call_count = len(read_sizes)
+        if call_count == 1:
+            return large_text[:size]
+        elif call_count == 2:
+            return large_text[READ_SIZE : READ_SIZE + size]
+        elif call_count == 3:
+            return large_text[READ_SIZE * 2 : READ_SIZE * 2 + size]
+        else:
+            return large_text[READ_SIZE * 3 :]
+
+    mock_file.read.side_effect = mock_read
+
+    payload_obj = payload.TextIOPayload(mock_file)
+    writer = MockStreamWriter()
+
+    # Write with a large content_length
+    await payload_obj.write_with_length(writer, len(large_text.encode("utf-8")))
+
+    # Verify that reads were limited to READ_SIZE
+    assert len(read_sizes) > 1  # Should have multiple reads
+    for read_size in read_sizes:
+        assert (
+            read_size <= READ_SIZE
+        ), f"Read size {read_size} exceeds READ_SIZE {READ_SIZE}"
+
+
+async def test_textio_payload_large_content_length() -> None:
+    """Test TextIOPayload with very large content_length doesn't read all at once."""
+    text_data = "x" * (READ_SIZE + 1000)
+
+    # Create a custom file-like object that tracks read sizes
+    class TrackingStringIO(io.StringIO):
+        def __init__(self, data: str) -> None:
+            super().__init__(data)
+            self.read_sizes: List[int] = []
+
+        def read(self, size: Optional[int] = -1) -> str:
+            self.read_sizes.append(size if size is not None else -1)
+            return super().read(size)
+
+    tracking_file = TrackingStringIO(text_data)
+    payload_obj = payload.TextIOPayload(tracking_file)
+    writer = MockStreamWriter()
+
+    # Write with a very large content_length (simulating the bug scenario)
+    large_content_length = 10 * 1024 * 1024  # 10MB
+    await payload_obj.write_with_length(writer, large_content_length)
+
+    # Verify no single read exceeded READ_SIZE
+    for read_size in tracking_file.read_sizes:
+        assert (
+            read_size <= READ_SIZE
+        ), f"Read size {read_size} exceeds READ_SIZE {READ_SIZE}"
+
+    # Verify the correct amount of data was written
+    assert writer.get_written_bytes() == text_data.encode("utf-8")
+
+
 async def test_async_iterable_payload_write_with_length_no_limit() -> None:
     """Test AsyncIterablePayload writing with no content length limit."""
 
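To run only the tests added here from a checkout of the repository (one possible invocation; the -k expression simply matches the new test names):

python -m pytest tests/test_payload.py -k "reads_in_chunks or large_content_length"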