diff --git a/aiosmtpd/docs/NEWS.rst b/aiosmtpd/docs/NEWS.rst index c2e51f30..e4af422b 100644 --- a/aiosmtpd/docs/NEWS.rst +++ b/aiosmtpd/docs/NEWS.rst @@ -12,7 +12,12 @@ Fixed/Improved * Dropped Python 3.8, PyPy 3.8 * Added PyPy 3.9 +* Buffering improvements to ``smtp_DATA()`` (should improve/fix #293) +Added +----- + +* ``handle_DATA_CHUNK()`` hook: chunked data receiving to avoid buffering the entire message in memory while it is being received 1.4.6 (2024-05-18) ================== diff --git a/aiosmtpd/docs/handlers.rst b/aiosmtpd/docs/handlers.rst index 3bd37f79..6b89bbca 100644 --- a/aiosmtpd/docs/handlers.rst +++ b/aiosmtpd/docs/handlers.rst @@ -98,6 +98,35 @@ The following hooks are currently supported (in alphabetical order): ``decode_data=False`` or ``decode_data=True``. See :attr:`Envelope.content` for more info. +.. py:method:: handle_DATA_CHUNK(server, session, envelope, data: bytes, text: Optional[str], last: bool) -> Optional[str] + :async: + + :return: Response message to be sent to the client + + Alternative to ``handle_DATA()``, under active development, subject to change. + + Called periodically throughout ``DATA`` as the message (`"SMTP content" + <https://tools.ietf.org/html/rfc5321#section-2.3.9>`_ as described in + RFC 5321) is received. + + The content is passed to ``data`` as type ``bytes``, + normalized according to the transparency rules + as defined in :rfc:`RFC 5321, §4.5.2 <5321#section-4.5.2>`. + + If :class:`~aiosmtpd.smtp.SMTP` was instantiated with + ``decode_data=True``, the decoded text will be passed to ``text`` + as a Python string. + + ``last`` will be ``False`` prior to the final call. The handler MAY + return a non-``None`` response prior to the ``last=True`` + call. This is treated as an error and terminates the + transaction. The handler will not be invoked again after a + non-``None`` response. Otherwise, the handler MUST return a + non-``None`` response to the ``last=True`` call. + + :class:`~aiosmtpd.smtp.SMTP` buffers data per the ``chunk_size`` + parameter. 
The hook may be invoked with an empty chunk at any time. + .. py:method:: handle_EHLO(server, session, envelope, hostname, responses) -> List[str] :async: :noindex: diff --git a/aiosmtpd/smtp.py b/aiosmtpd/smtp.py index 3b85fcfb..582310c5 100644 --- a/aiosmtpd/smtp.py +++ b/aiosmtpd/smtp.py @@ -32,6 +32,7 @@ Union, ) from warnings import warn +from io import BytesIO import attr from public import public @@ -330,7 +331,8 @@ def __init__( command_call_limit: Union[int, Dict[str, int], None] = None, authenticator: Optional[AuthenticatorType] = None, proxy_protocol_timeout: Optional[Union[int, float]] = None, - loop: Optional[asyncio.AbstractEventLoop] = None + loop: Optional[asyncio.AbstractEventLoop] = None, + chunk_size: int = 2**16, ): self.__ident__ = ident or __ident__ self.loop = loop if loop else make_loop() @@ -373,6 +375,7 @@ def __init__( "can lead to security vulnerabilities!") log.warning("auth_required == True but auth_require_tls == False") self._auth_require_tls = auth_require_tls + self._chunk_size = chunk_size if proxy_protocol_timeout is not None: if proxy_protocol_timeout <= 0: @@ -1404,6 +1407,22 @@ async def smtp_RSET(self, arg: str): status = await self._call_handler_hook('RSET') await self.push('250 OK' if status is MISSING else status) + # -> err, decoded data + def _decode_line(self, data: bytes + ) -> Tuple[Union[_Missing, bytes], Optional[str]]: + if not self._decode_data: + return MISSING, None + if self.enable_SMTPUTF8: + return MISSING, data.decode('utf-8', errors='surrogateescape') + else: + try: + return MISSING, data.decode('ascii', errors='strict') + except UnicodeDecodeError: + # This happens if enable_smtputf8 is false, meaning that + # the server explicitly does not want to accept non-ascii, + # but the client ignores that and sends non-ascii anyway. 
+ return b'500 Error: strict ASCII mode', None + @syntax('DATA') async def smtp_DATA(self, arg: str) -> None: if await self.check_helo_needed(): @@ -1418,13 +1437,24 @@ async def smtp_DATA(self, arg: str) -> None: await self.push('501 Syntax: DATA') return - await self.push('354 End data with .') - data: List[bytearray] = [] + chunking = "DATA_CHUNK" in self._handle_hooks + + data_status = '354 End data with .' + if chunking and ((s := await self._call_handler_hook( + 'DATA_CHUNK', b'', + '' if self._decode_data else None, False)) is not None): + data_status = s + await self.push(data_status) + if data_status[0:3] != '354': + return + data = BytesIO() num_bytes: int = 0 limit: Optional[int] = self.data_size_limit - line_fragments: List[bytes] = [] state: _DataState = _DataState.NOMINAL + status : Union[_Missing, bytes] = MISSING + DOT = ord('.') + while self.transport is not None: # pragma: nobranch # Since eof_received cancels this coroutine, # readuntil() can never raise asyncio.IncompleteReadError. @@ -1442,16 +1472,23 @@ async def smtp_DATA(self, arg: str) -> None: # The line exceeds StreamReader's "stream limit". # Delay SMTP Status Code sending until data receive is complete # This seems to be implied in RFC 5321 § 4.2.5 + + # TODO this (and _handle_client()) will currently read + # an unbounded amount of data from the client looking + # for crlf. Possibly this should return an immediate + # error and close the connection after some limit ~16kb. if state == _DataState.NOMINAL: # Transition to TOO_LONG only if we haven't gone TOO_MUCH yet state = _DataState.TOO_LONG # Discard data immediately to prevent memory pressure - data *= 0 + data.truncate(0) # Drain the stream anyways line = await self._reader.read(e.consumed) assert not line.endswith(b'\r\n') + continue + # A lone dot in a line signals the end of DATA. 
- if not line_fragments and line == b'.\r\n': + if line == b'.\r\n': break num_bytes += len(line) if state == _DataState.NOMINAL and limit and num_bytes > limit: @@ -1459,62 +1496,85 @@ async def smtp_DATA(self, arg: str) -> None: # This seems to be implied in RFC 5321 § 4.2.5 state = _DataState.TOO_MUCH # Discard data immediately to prevent memory pressure - data *= 0 - line_fragments.append(line) - if line.endswith(b'\r\n'): - # Record data only if state is "NOMINAL" - if state == _DataState.NOMINAL: - line = EMPTY_BARR.join(line_fragments) - if len(line) > self.line_length_limit: - # Theoretically we shouldn't reach this place. But it's always - # good to practice DEFENSIVE coding. - state = _DataState.TOO_LONG - # Discard data immediately to prevent memory pressure - data *= 0 - else: - data.append(EMPTY_BARR.join(line_fragments)) - line_fragments *= 0 + data.truncate(0) + assert line.endswith(b'\r\n') + assert len(line) <= (self.line_length_limit + 2) + # Record data only if state is "NOMINAL" + if state != _DataState.NOMINAL: + continue + + # Remove extraneous carriage returns and de-transparency + # according to RFC 5321, Section 4.5.2. + if line[0] == DOT: + line = line[1:] + + if not chunking: + data.write(line) + continue + + # reset timeout on every read so it can take as long as it + # takes as long as the client keeps making forward + # progress + # TODO we could do this for !chunking under control of a + # flag so as not to change the behavior unexpectedly? + self._reset_timeout() + + if data.tell() + len(line) > self._chunk_size: + data.seek(0) + chunk = data.read() + data.truncate(0) + data.seek(0) + decoded_line = None + if status is MISSING: + status, decoded_line = self._decode_line(chunk) + if status is MISSING and ( + (s := await self._call_handler_hook( + 'DATA_CHUNK', chunk, decoded_line, False)) is not None): + status = s + data.write(line) # Day of reckoning! 
Let's take care of those out-of-nominal situations - if state != _DataState.NOMINAL: + if state != _DataState.NOMINAL or status is not MISSING: if state == _DataState.TOO_LONG: await self.push("500 Line too long (see RFC5321 4.5.3.1.6)") elif state == _DataState.TOO_MUCH: # pragma: nobranch await self.push('552 Error: Too much mail data') + elif status is not MISSING: + await self.push(status) self._set_post_data_state() return - # If unfinished_line is non-empty, then the connection was closed. - assert not line_fragments - - # Remove extraneous carriage returns and de-transparency - # according to RFC 5321, Section 4.5.2. - for text in data: - if text.startswith(b'.'): - del text[0] - original_content: bytes = EMPTYBYTES.join(data) + data.seek(0) + original_content: bytes = data.read() # Discard data immediately to prevent memory pressure - data *= 0 + data.truncate(0) - content: Union[str, bytes] + content: Union[str, bytes, None] = None if self._decode_data: - if self.enable_SMTPUTF8: - content = original_content.decode('utf-8', errors='surrogateescape') - else: - try: - content = original_content.decode('ascii', errors='strict') - except UnicodeDecodeError: - # This happens if enable_smtputf8 is false, meaning that - # the server explicitly does not want to accept non-ascii, - # but the client ignores that and sends non-ascii anyway. - await self.push('500 Error: strict ASCII mode') - return - else: + status, content = self._decode_line(original_content) + if status is not MISSING: + await self.push(status) + return + + # Call the new API first if it's implemented. 
+ if chunking: + assert status is MISSING # handled above + new_status = await self._call_handler_hook( + 'DATA_CHUNK', original_content, content, True) + assert new_status is not None + status = new_status + assert status is not MISSING + self._set_post_data_state() + await self.push(status) + return + + if not self._decode_data: content = original_content + assert content is not None + self.envelope.content = content self.envelope.original_content = original_content - # Call the new API first if it's implemented. if "DATA" in self._handle_hooks: status = await self._call_handler_hook('DATA') else: @@ -1526,18 +1586,20 @@ async def smtp_DATA(self, arg: str) -> None: assert self.session is not None args = (self.session.peer, self.envelope.mail_from, self.envelope.rcpt_tos, self.envelope.content) + old_status : Optional[bytes] = None if asyncio.iscoroutinefunction( self.event_handler.process_message): - status = await self.event_handler.process_message(*args) + old_status = await self.event_handler.process_message(*args) else: - status = self.event_handler.process_message(*args) + old_status = self.event_handler.process_message(*args) # The deprecated API can return None which means, return the - # default status. Don't worry about coverage for this case as - # it's a deprecated API that will go away after 1.0. - if status is None: # pragma: nocover + # default status. + if old_status is None: status = MISSING + else: + status = old_status self._set_post_data_state() - await self.push('250 OK' if status is MISSING else status) + await self.push(b'250 OK' if status is MISSING else status) # Commands that have not been implemented. 
async def smtp_EXPN(self, arg: str): diff --git a/aiosmtpd/testing/helpers.py b/aiosmtpd/testing/helpers.py index 8ee630a7..d8ff6ae0 100644 --- a/aiosmtpd/testing/helpers.py +++ b/aiosmtpd/testing/helpers.py @@ -10,7 +10,7 @@ import sys import time from smtplib import SMTP as SMTP_Client -from typing import List +from typing import List, Optional from aiosmtpd.smtp import Envelope, Session, SMTP @@ -58,6 +58,39 @@ async def handle_DATA( return "250 OK" +class ChunkedReceivingHandler: + def __init__(self): + self.box: List[Envelope] = [] + self.responses: List[Optional[str]] = [None, '250 OK'] + self.sent_response = False + + async def handle_DATA_CHUNK( + self, server: SMTP, session: Session, envelope: Envelope, + data: bytes, text: Optional[str], last: bool, + ) -> Optional[str]: + assert not self.sent_response + if text is not None: + if envelope.content is None: + envelope.content = '' + assert isinstance(envelope.content, str) + envelope.content += text + if envelope.original_content is None: + envelope.original_content = b'' + envelope.original_content += data + else: + if envelope.content is None: + envelope.content = b'' + assert isinstance(envelope.content, bytes) + envelope.content += data + + if last: + self.box.append(envelope) + resp = self.responses.pop(0) + if resp is not None: + self.sent_response = True + return resp + + def catchup_delay(delay: float = ASYNCIO_CATCHUP_DELAY): """ Sleep for awhile to give asyncio's event loop time to catch up. 
diff --git a/aiosmtpd/tests/test_handlers.py b/aiosmtpd/tests/test_handlers.py index 392689db..f4c2618e 100644 --- a/aiosmtpd/tests/test_handlers.py +++ b/aiosmtpd/tests/test_handlers.py @@ -11,7 +11,7 @@ from smtplib import SMTPDataError, SMTPRecipientsRefused from textwrap import dedent from types import SimpleNamespace -from typing import AnyStr, Callable, Generator, Type, TypeVar, Union +from typing import AnyStr, Callable, Generator, Optional, Type, TypeVar, Union import pytest @@ -187,8 +187,11 @@ def factory(self): class DeprecatedHandler: + def __init__(self): + self.response: Optional[str] = None + def process_message(self, peer, mailfrom, rcpttos, data, **kws): - pass + return self.response class AsyncDeprecatedHandler: @@ -986,11 +989,20 @@ def _process_message_testing(self, controller, client): ), ) + @handler_data(class_=DeprecatedHandler) + def test_process_message_no_response(self, plain_controller, client): + """handler.process_message is Deprecated""" + handler = plain_controller.handler + assert isinstance(handler, DeprecatedHandler) + controller = plain_controller + self._process_message_testing(controller, client) + @handler_data(class_=DeprecatedHandler) def test_process_message(self, plain_controller, client): """handler.process_message is Deprecated""" handler = plain_controller.handler assert isinstance(handler, DeprecatedHandler) + handler.response = '250 ok' controller = plain_controller self._process_message_testing(controller, client) diff --git a/aiosmtpd/tests/test_smtp.py b/aiosmtpd/tests/test_smtp.py index 008ba8cd..3b2db4f1 100644 --- a/aiosmtpd/tests/test_smtp.py +++ b/aiosmtpd/tests/test_smtp.py @@ -42,6 +42,7 @@ auth_mechanism, ) from aiosmtpd.testing.helpers import ( + ChunkedReceivingHandler, ReceivingHandler, catchup_delay, reset_connection, @@ -1603,23 +1604,6 @@ def test_long_line_double_count(self, plain_controller, client): client.sendmail("anne@example.com", ["bart@example.com"], mail) assert exc.value.args == 
S.S500_DATALINE_TOO_LONG - def test_long_line_leak(self, mocker: MockFixture, plain_controller, client): - # Simulates situation where readuntil() does not raise LimitOverrunError, - # but somehow the line_fragments when join()ed resulted in a too-long line - - # Hijack EMPTY_BARR.join() to return a bytes object that's definitely too long - mock_ebarr = mocker.patch("aiosmtpd.smtp.EMPTY_BARR") - mock_ebarr.join.return_value = b"a" * 1010 - - client.helo("example.com") - mail = "z" * 72 # Make sure this is small and definitely within limits - with pytest.raises(SMTPDataError) as exc: - client.sendmail("anne@example.com", ["bart@example.com"], mail) - assert exc.value.args == S.S500_DATALINE_TOO_LONG - # self.assertEqual(cm.exception.smtp_code, 500) - # self.assertEqual(cm.exception.smtp_error, - # b'Line too long (see RFC5321 4.5.3.1.6)') - @controller_data(data_size_limit=20) def test_too_long_body_delay_error(self, plain_controller): with socket.socket() as sock: @@ -1677,6 +1661,114 @@ def test_too_long_lines_then_too_long_body(self, plain_controller, client): client.sendmail("anne@example.com", ["bart@example.com"], mail) assert exc.value.args == S.S500_DATALINE_TOO_LONG + @controller_data(decode_data=True) + @handler_data(class_=ChunkedReceivingHandler) + def test_chunked_receiving(self, plain_controller, client): + smtpd: Server = plain_controller.smtpd + smtpd._chunk_size = 10 + handler = plain_controller.handler + handler.responses = [None, None, '250 OK'] + self._ehlo(client) + client.send(b'MAIL FROM:\r\n') + assert client.getreply() == S.S250_OK + client.send(b'RCPT TO:\r\n') + assert client.getreply() == S.S250_OK + client.send(b'DATA\r\n') + assert client.getreply() == S.S354_DATA_ENDWITH + client.send(b'hello, \r\n') # fits in chunk_size + client.send(b'\xe4\xb8\x96\xe7\x95\x8c!\r\n') # overflow -> flush + client.send(b'.\r\n') + assert client.getreply() == S.S250_OK + + assert len(handler.box) == 1 + envelope = handler.box[0] + assert 
envelope.original_content == b'hello, \r\n\xe4\xb8\x96\xe7\x95\x8c!\r\n' + assert envelope.content == 'hello, \r\n世界!\r\n' + + @controller_data(decode_data=False) + @handler_data(class_=ChunkedReceivingHandler) + def test_chunked_receiving_no_decode(self, plain_controller, client): + smtpd: Server = plain_controller.smtpd + smtpd._chunk_size = 10 + handler = plain_controller.handler + handler.responses = [None, None, '250 OK'] + self._ehlo(client) + client.send(b'MAIL FROM:\r\n') + assert client.getreply() == S.S250_OK + client.send(b'RCPT TO:\r\n') + assert client.getreply() == S.S250_OK + client.send(b'DATA\r\n') + assert client.getreply() == S.S354_DATA_ENDWITH + client.send(b'hello, \r\n') + client.send(b'\xe4\xb8\x96\xe7\x95\x8c!\r\n') + client.send(b'.\r\n') + assert client.getreply() == S.S250_OK + + assert len(handler.box) == 1 + envelope = handler.box[0] + assert envelope.content == b'hello, \r\n\xe4\xb8\x96\xe7\x95\x8c!\r\n' + assert envelope.original_content is None + + @controller_data(decode_data=True) + @handler_data(class_=ChunkedReceivingHandler) + def test_chunked_receiving_data_response_err(self, plain_controller, client): + smtpd: Server = plain_controller.smtpd + smtpd._chunk_size = 10 + handler = plain_controller.handler + handler.responses = ['550 bad'] + self._ehlo(client) + client.send(b'MAIL FROM:\r\n') + assert client.getreply() == S.S250_OK + client.send(b'RCPT TO:\r\n') + assert client.getreply() == S.S250_OK + client.send(b'DATA\r\n') + assert client.getreply() == (550, b'bad') + assert len(handler.box) == 0 + + @controller_data(decode_data=True) + @handler_data(class_=ChunkedReceivingHandler) + def test_chunked_receiving_non_last_err(self, plain_controller, client): + smtpd: Server = plain_controller.smtpd + smtpd._chunk_size = 10 + handler = plain_controller.handler + handler.responses = [None, '550 bad'] + self._ehlo(client) + client.send(b'MAIL FROM:\r\n') + assert client.getreply() == S.S250_OK + client.send(b'RCPT TO:\r\n') + 
assert client.getreply() == S.S250_OK + client.send(b'DATA\r\n') + assert client.getreply() == S.S354_DATA_ENDWITH + client.send(b'hello, \r\n') # fits in chunk_size + client.send(b'\xe4\xb8\x96\xe7\x95\x8c!\r\n') # overflow -> flush + client.send(b'more data\r\n') + client.send(b'.\r\n') + assert client.getreply() == (550, b'bad') + + assert len(handler.box) == 0 + + @controller_data(decode_data=True, enable_SMTPUTF8=False) + @handler_data(class_=ChunkedReceivingHandler) + def test_chunked_receiving_decode_err(self, plain_controller, client): + smtpd: Server = plain_controller.smtpd + smtpd._chunk_size = 10 + handler = plain_controller.handler + handler.responses = [None, None] + self._ehlo(client) + client.send(b'MAIL FROM:\r\n') + assert client.getreply() == S.S250_OK + client.send(b'RCPT TO:\r\n') + assert client.getreply() == S.S250_OK + client.send(b'DATA\r\n') + assert client.getreply() == S.S354_DATA_ENDWITH + client.send(b'hello, \r\n') # fits in chunk_size + client.send(b'\xe4\xb8\x96\xe7\x95\x8c!\r\n') # overflow -> flush + client.send(b'more data\r\n') + client.send(b'.\r\n') + assert client.getreply() == S.S500_STRICT_ASCII + + assert len(handler.box) == 0 + class TestCustomization(_CommonMethods): @controller_data(class_=CustomHostnameController)