 
 from aiohttp import payload
 from aiohttp.abc import AbstractStreamWriter
+from aiohttp.payload import READ_SIZE
 
 
 class BufferWriter(AbstractStreamWriter):
@@ -363,6 +364,155 @@ async def test_iobase_payload_exact_chunk_size_limit() -> None:
     assert written == data[:chunk_size]
 
 
+async def test_iobase_payload_reads_in_chunks() -> None:
+    """Test IOBasePayload reads data in chunks of READ_SIZE, not all at once."""
+    # Create a large file that's multiple times larger than READ_SIZE
+    large_data = b"x" * (READ_SIZE * 3 + 1000)  # ~192KB + 1000 bytes
+
+    # Mock the file-like object to track read calls
+    mock_file = unittest.mock.Mock(spec=io.BytesIO)
+    mock_file.tell.return_value = 0
+    mock_file.fileno.side_effect = AttributeError  # Make size return None
+
+    # Track the sizes of read() calls
+    read_sizes = []
+
+    def mock_read(size: int) -> bytes:
+        read_sizes.append(size)
+        # Return data based on how many times read was called
+        call_count = len(read_sizes)
+        if call_count == 1:
+            return large_data[:size]
+        elif call_count == 2:
+            return large_data[READ_SIZE : READ_SIZE + size]
+        elif call_count == 3:
+            return large_data[READ_SIZE * 2 : READ_SIZE * 2 + size]
+        else:
+            return large_data[READ_SIZE * 3 :]
+
+    mock_file.read.side_effect = mock_read
+
+    payload_obj = payload.IOBasePayload(mock_file)
+    writer = MockStreamWriter()
+
+    # Write with a large content_length
+    await payload_obj.write_with_length(writer, len(large_data))
+
+    # Verify that reads were limited to READ_SIZE
+    assert len(read_sizes) > 1  # Should have multiple reads
+    for read_size in read_sizes:
+        assert (
+            read_size <= READ_SIZE
+        ), f"Read size {read_size} exceeds READ_SIZE {READ_SIZE}"
+
+
+async def test_iobase_payload_large_content_length() -> None:
+    """Test IOBasePayload with very large content_length doesn't read all at once."""
+    data = b"x" * (READ_SIZE + 1000)
+
+    # Create a custom file-like object that tracks read sizes
+    class TrackingBytesIO(io.BytesIO):
+        def __init__(self, data: bytes) -> None:
+            super().__init__(data)
+            self.read_sizes: List[int] = []
+
+        def read(self, size: Optional[int] = -1) -> bytes:
+            self.read_sizes.append(size if size is not None else -1)
+            return super().read(size)
+
+    tracking_file = TrackingBytesIO(data)
+    payload_obj = payload.IOBasePayload(tracking_file)
+    writer = MockStreamWriter()
+
+    # Write with a very large content_length (simulating the bug scenario)
+    large_content_length = 10 * 1024 * 1024  # 10MB
+    await payload_obj.write_with_length(writer, large_content_length)
+
+    # Verify no single read exceeded READ_SIZE
+    for read_size in tracking_file.read_sizes:
+        assert (
+            read_size <= READ_SIZE
+        ), f"Read size {read_size} exceeds READ_SIZE {READ_SIZE}"
+
+    # Verify the correct amount of data was written
+    assert writer.get_written_bytes() == data
+
+
+async def test_textio_payload_reads_in_chunks() -> None:
+    """Test TextIOPayload reads data in chunks of READ_SIZE, not all at once."""
+    # Create a large text file that's multiple times larger than READ_SIZE
+    large_text = "x" * (READ_SIZE * 3 + 1000)  # ~192KB + 1000 chars
+
+    # Mock the file-like object to track read calls
+    mock_file = unittest.mock.Mock(spec=io.StringIO)
+    mock_file.tell.return_value = 0
+    mock_file.fileno.side_effect = AttributeError  # Make size return None
+    mock_file.encoding = "utf-8"
+
+    # Track the sizes of read() calls
+    read_sizes = []
+
+    def mock_read(size: int) -> str:
+        read_sizes.append(size)
+        # Return data based on how many times read was called
+        call_count = len(read_sizes)
+        if call_count == 1:
+            return large_text[:size]
+        elif call_count == 2:
+            return large_text[READ_SIZE : READ_SIZE + size]
+        elif call_count == 3:
+            return large_text[READ_SIZE * 2 : READ_SIZE * 2 + size]
+        else:
+            return large_text[READ_SIZE * 3 :]
+
+    mock_file.read.side_effect = mock_read
+
+    payload_obj = payload.TextIOPayload(mock_file)
+    writer = MockStreamWriter()
+
+    # Write with a large content_length
+    await payload_obj.write_with_length(writer, len(large_text.encode("utf-8")))
+
+    # Verify that reads were limited to READ_SIZE
+    assert len(read_sizes) > 1  # Should have multiple reads
+    for read_size in read_sizes:
+        assert (
+            read_size <= READ_SIZE
+        ), f"Read size {read_size} exceeds READ_SIZE {READ_SIZE}"
+
+
+async def test_textio_payload_large_content_length() -> None:
+    """Test TextIOPayload with very large content_length doesn't read all at once."""
+    text_data = "x" * (READ_SIZE + 1000)
+
+    # Create a custom file-like object that tracks read sizes
+    class TrackingStringIO(io.StringIO):
+        def __init__(self, data: str) -> None:
+            super().__init__(data)
+            self.read_sizes: List[int] = []
+
+        def read(self, size: Optional[int] = -1) -> str:
+            self.read_sizes.append(size if size is not None else -1)
+            return super().read(size)
+
+    tracking_file = TrackingStringIO(text_data)
+    payload_obj = payload.TextIOPayload(tracking_file)
+    writer = MockStreamWriter()
+
+    # Write with a very large content_length (simulating the bug scenario)
+    large_content_length = 10 * 1024 * 1024  # 10MB
+    await payload_obj.write_with_length(writer, large_content_length)
+
+    # Verify no single read exceeded READ_SIZE
+    for read_size in tracking_file.read_sizes:
+        assert (
+            read_size <= READ_SIZE
+        ), f"Read size {read_size} exceeds READ_SIZE {READ_SIZE}"
+
+    # Verify the correct amount of data was written
+    assert writer.get_written_bytes() == text_data.encode("utf-8")
+
+
 async def test_async_iterable_payload_write_with_length_no_limit() -> None:
     """Test AsyncIterablePayload writing with no content length limit."""
 
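
Taken together, these tests pin down one contract: `write_with_length()` must never ask the underlying file object for more than `READ_SIZE` bytes (or characters, for `TextIOPayload`) in a single `read()` call, no matter how large `content_length` is. Below is a minimal sketch of that kind of chunked copy loop; it illustrates the contract only and is not aiohttp's actual implementation, and the `ChunkWriter` protocol, the function name, and the `READ_SIZE` value here are stand-ins.

```python
from typing import BinaryIO, Protocol

READ_SIZE = 2**16  # assumed 64 KiB cap per read() call; stand-in for aiohttp.payload.READ_SIZE


class ChunkWriter(Protocol):
    async def write(self, chunk: bytes) -> None: ...


async def chunked_write_sketch(
    writer: ChunkWriter, fp: BinaryIO, content_length: int
) -> None:
    """Copy at most content_length bytes from fp to writer, READ_SIZE bytes per read()."""
    remaining = content_length
    while remaining > 0:
        # Never request more than READ_SIZE bytes in one read() call,
        # even when content_length is far larger than the file.
        chunk = fp.read(min(READ_SIZE, remaining))
        if not chunk:  # EOF before content_length was exhausted
            break
        await writer.write(chunk)
        remaining -= len(chunk)
```

The tests above assert exactly the property this loop guarantees: every recorded `read()` size stays at or below `READ_SIZE`, even when the caller passes a `content_length` of many megabytes.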