diff --git a/Makefile b/Makefile index c4738cd..2d94c53 100644 --- a/Makefile +++ b/Makefile @@ -31,7 +31,9 @@ test: install .PHONY: lint lint: #! Run type analysis and linting checks lint: install - poetry run mypy ld_eventsource + @poetry run mypy ld_eventsource + @poetry run isort --check --atomic ld_eventsource contract-tests + @poetry run pycodestyle ld_eventsource contract-tests # # Documentation generation diff --git a/contract-tests/service.py b/contract-tests/service.py index e29b9b3..3fc173d 100644 --- a/contract-tests/service.py +++ b/contract-tests/service.py @@ -1,13 +1,12 @@ -from stream_entity import StreamEntity - import json import logging import os import sys -import urllib3 +from logging.config import dictConfig + from flask import Flask, request from flask.logging import default_handler -from logging.config import dictConfig +from stream_entity import StreamEntity default_port = 8000 @@ -30,7 +29,7 @@ 'handlers': ['console'] }, 'loggers': { - 'werkzeug': { 'level': 'ERROR' } # disable irrelevant Flask app logging + 'werkzeug': {'level': 'ERROR'} # disable irrelevant Flask app logging } }) @@ -54,11 +53,13 @@ def status(): } return (json.dumps(body), 200, {'Content-type': 'application/json'}) + @app.route('/', methods=['DELETE']) def delete_stop_service(): global_log.info("Test service has told us to exit") os._exit(0) + @app.route('/', methods=['POST']) def post_create_stream(): global stream_counter, streams @@ -74,6 +75,7 @@ def post_create_stream(): return ('', 201, {'Location': resource_url}) + @app.route('/streams/', methods=['POST']) def post_stream_command(id): global streams @@ -87,6 +89,7 @@ def post_stream_command(id): return ('', 400) return ('', 204) + @app.route('/streams/', methods=['DELETE']) def delete_stream(id): global streams @@ -97,6 +100,7 @@ def delete_stream(id): stream.close() return ('', 204) + if __name__ == "__main__": port = default_port if sys.argv[len(sys.argv) - 1] != 'service.py': diff --git a/contract-tests/stream_entity.py b/contract-tests/stream_entity.py index 70b7d31..01ab52e 100644 --- a/contract-tests/stream_entity.py +++ b/contract-tests/stream_entity.py @@ -4,17 +4,18 @@ import sys import threading import traceback + import urllib3 # Import ld_eventsource from parent directory sys.path.insert(1, os.path.join(sys.path[0], '..')) -from ld_eventsource import * -from ld_eventsource.actions import * -from ld_eventsource.config import * - +from ld_eventsource import * # noqa: E402 +from ld_eventsource.actions import * # noqa: E402 +from ld_eventsource.config import * # noqa: E402 http_client = urllib3.PoolManager() + def millis_to_seconds(t): return None if t is None else t / 1000 @@ -27,7 +28,7 @@ def __init__(self, options): self.closed = False self.callback_counter = 0 self.sse = None - + thread = threading.Thread(target=self.run) thread.start() @@ -38,36 +39,47 @@ def run(self): connect = ConnectStrategy.http( url=stream_url, headers=self.options.get("headers"), - urllib3_request_options=None if self.options.get("readTimeoutMs") is None else { - "timeout": urllib3.Timeout(read=millis_to_seconds(self.options.get("readTimeoutMs"))) - } - ) + urllib3_request_options=( + None + if self.options.get("readTimeoutMs") is None + else { + "timeout": urllib3.Timeout( + read=millis_to_seconds(self.options.get("readTimeoutMs")) + ) + } + ), + ) sse = SSEClient( connect, - initial_retry_delay=millis_to_seconds(self.options.get("initialDelayMs")), + initial_retry_delay=millis_to_seconds( + self.options.get("initialDelayMs") + ), last_event_id=self.options.get("lastEventId"), - error_strategy=ErrorStrategy.from_lambda(lambda _: - (ErrorStrategy.FAIL if self.closed else ErrorStrategy.CONTINUE, None)), - logger=self.log + error_strategy=ErrorStrategy.from_lambda( + lambda _: ( + ErrorStrategy.FAIL if self.closed else ErrorStrategy.CONTINUE, + None, + ) + ), + logger=self.log, ) self.sse = sse for item in sse.all: if isinstance(item, Event): self.log.info('Received event from stream (%s)', item.event) - self.send_message({ - 'kind': 'event', - 'event': { - 'type': item.event, - 'data': item.data, - 'id': item.last_event_id + self.send_message( + { + 'kind': 'event', + 'event': { + 'type': item.event, + 'data': item.data, + 'id': item.last_event_id, + }, } - }) + ) elif isinstance(item, Comment): self.log.info('Received comment from stream: %s', item.comment) - self.send_message({ - 'kind': 'comment', - 'comment': item.comment - }) + self.send_message({'kind': 'comment', 'comment': item.comment}) elif isinstance(item, Fault): if self.closed: break @@ -75,23 +87,17 @@ def run(self): # Currently the test harness does not expect us to send an error message in that case. if item.error: self.log.info('Received error from stream: %s' % item.error) - self.send_message({ - 'kind': 'error', - 'error': str(item.error) - }) + self.send_message({'kind': 'error', 'error': str(item.error)}) except Exception as e: self.log.info('Received error from stream: %s', e) self.log.info(traceback.format_exc()) - self.send_message({ - 'kind': 'error', - 'error': str(e) - }) + self.send_message({'kind': 'error', 'error': str(e)}) def do_command(self, command: str) -> bool: self.log.info('Test service sent command: %s' % command) # currently we support no special commands return False - + def send_message(self, message): global http_client @@ -104,9 +110,9 @@ def send_message(self, message): resp = http_client.request( 'POST', callback_url, - headers = {'Content-Type': 'application/json'}, - body = json.dumps(message) - ) + headers={'Content-Type': 'application/json'}, + body=json.dumps(message), + ) if resp.status >= 300 and not self.closed: self.log.error('Callback request returned HTTP error %d', resp.status) except Exception as e: diff --git a/ld_eventsource/actions.py b/ld_eventsource/actions.py index 7216e36..276eec1 100644 --- a/ld_eventsource/actions.py +++ b/ld_eventsource/actions.py @@ -6,6 +6,7 @@ class Action: """ Base class for objects that can be returned by :attr:`.SSEClient.all`. """ + pass @@ -16,11 +17,13 @@ class Event(Action): Instances of this class are returned by both :attr:`.SSEClient.events` and :attr:`.SSEClient.all`. """ - def __init__(self, - event: str='message', - data: str='', - id: Optional[str]=None, - last_event_id: Optional[str]=None + + def __init__( + self, + event: str = 'message', + data: str = '', + id: Optional[str] = None, + last_event_id: Optional[str] = None, ): self._event = event self._data = data @@ -58,27 +61,31 @@ def last_event_id(self) -> Optional[str]: def __eq__(self, other): if not isinstance(other, Event): return False - return self._event == other._event and self._data == other._data \ - and self._id == other._id and self.last_event_id == other.last_event_id + return ( + self._event == other._event + and self._data == other._data + and self._id == other._id + and self.last_event_id == other.last_event_id + ) def __repr__(self): return "Event(event=\"%s\", data=%s, id=%s, last_event_id=%s)" % ( self._event, json.dumps(self._data), "None" if self._id is None else json.dumps(self._id), - "None" if self._last_event_id is None else json.dumps(self._last_event_id) + "None" if self._last_event_id is None else json.dumps(self._last_event_id), ) class Comment(Action): """ A comment received by :class:`.SSEClient`. - + Comment lines (any line beginning with a colon) have no significance in the SSE specification and can be ignored, but if you want to see them, use :attr:`.SSEClient.all`. They will never be returned by :attr:`.SSEClient.events`. """ - + def __init__(self, comment: str): self._comment = comment @@ -104,6 +111,7 @@ class Start(Action): A ``Start`` is returned for the first successful connection. If the client reconnects after a failure, there will be a :class:`.Fault` followed by a ``Start``. """ + pass @@ -121,7 +129,7 @@ class Fault(Action): def __init__(self, error: Optional[Exception]): self.__error = error - + @property def error(self) -> Optional[Exception]: """ diff --git a/ld_eventsource/config/__init__.py b/ld_eventsource/config/__init__.py index 649e19c..8eadfb6 100644 --- a/ld_eventsource/config/__init__.py +++ b/ld_eventsource/config/__init__.py @@ -1,3 +1,4 @@ -from .connect_strategy import ConnectStrategy, ConnectionClient, ConnectionResult +from .connect_strategy import (ConnectionClient, ConnectionResult, + ConnectStrategy) from .error_strategy import ErrorStrategy from .retry_delay_strategy import RetryDelayStrategy diff --git a/ld_eventsource/config/connect_strategy.py b/ld_eventsource/config/connect_strategy.py index 39cb5a0..1b59f67 100644 --- a/ld_eventsource/config/connect_strategy.py +++ b/ld_eventsource/config/connect_strategy.py @@ -1,6 +1,8 @@ from __future__ import annotations + from logging import Logger from typing import Callable, Iterator, Optional, Union + from urllib3 import PoolManager from ld_eventsource.http import _HttpClientImpl, _HttpConnectParams @@ -33,9 +35,9 @@ def create_client(self, logger: Logger) -> ConnectionClient: @staticmethod def http( url: str, - headers: Optional[dict]=None, - pool: Optional[PoolManager]=None, - urllib3_request_options: Optional[dict]=None + headers: Optional[dict] = None, + pool: Optional[PoolManager] = None, + urllib3_request_options: Optional[dict] = None, ) -> ConnectStrategy: """ Creates the default HTTP implementation, specifying request parameters. @@ -46,7 +48,9 @@ def http( :param urllib3_request_options: optional ``kwargs`` to add to the ``request`` call; these can include any parameters supported by ``urllib3``, such as ``timeout`` """ - return _HttpConnectStrategy(_HttpConnectParams(url, headers, pool, urllib3_request_options)) + return _HttpConnectStrategy( + _HttpConnectParams(url, headers, pool, urllib3_request_options) + ) class ConnectionClient: @@ -65,7 +69,9 @@ def connect(self, last_event_id: Optional[str]) -> ConnectionResult: (should be sent to the server to support resuming an interrupted stream) :return: a :class:`ConnectionResult` representing the stream """ - raise NotImplementedError("ConnectionClient base class cannot be used by itself") + raise NotImplementedError( + "ConnectionClient base class cannot be used by itself" + ) def close(self): """ @@ -80,16 +86,12 @@ def __exit__(self, type, value, traceback): self.close() - class ConnectionResult: """ The return type of :meth:`ConnectionClient.connect()`. """ - def __init__( - self, - stream: Iterator[bytes], - closer: Optional[Callable] - ): + + def __init__(self, stream: Iterator[bytes], closer: Optional[Callable]): self.__stream = stream self.__closer = closer @@ -118,6 +120,7 @@ def __exit__(self, type, value, traceback): # _HttpConnectStrategy and _HttpConnectionClient are defined here rather than in http.py to avoid # a circular module reference. + class _HttpConnectStrategy(ConnectStrategy): def __init__(self, params: _HttpConnectParams): self.__params = params diff --git a/ld_eventsource/config/error_strategy.py b/ld_eventsource/config/error_strategy.py index 0f39467..2014d86 100644 --- a/ld_eventsource/config/error_strategy.py +++ b/ld_eventsource/config/error_strategy.py @@ -1,4 +1,5 @@ from __future__ import annotations + import time from typing import Callable, Optional, Tuple @@ -7,7 +8,7 @@ class ErrorStrategy: """ Base class of strategies for determining how SSEClient should handle a stream error or the end of a stream. - + The parameter that SSEClient passes to :meth:`apply()` is either ``None`` if the server ended the stream normally, or an exception. If it is an exception, it could be an I/O exception (failure to connect, broken connection, etc.), or one of the error types defined in this @@ -24,7 +25,7 @@ class ErrorStrategy: With either option, it is still always possible to explicitly reconnect the stream by calling :meth:`.SSEClient.start()` again, or simply by trying to read from :attr:`.SSEClient.events` or :attr:`.SSEClient.all` again. - + Subclasses should be immutable. To implement strategies that behave differently on consecutive retries, the strategy should return a new instance of its own class as the second return value from ``apply``, rather than modifying the state of the existing instance. This makes it easy @@ -73,7 +74,7 @@ def continue_with_max_attempts(max_attempts: int) -> ErrorStrategy: :param max_attempts: the maximum number of consecutive retries """ return _MaxAttemptsErrorStrategy(max_attempts, 0) - + @staticmethod def continue_with_time_limit(max_time: float) -> ErrorStrategy: """ @@ -85,11 +86,13 @@ def continue_with_time_limit(max_time: float) -> ErrorStrategy: return _TimeLimitErrorStrategy(max_time, 0) @staticmethod - def from_lambda(fn: Callable[[Optional[Exception]], Tuple[bool, Optional[ErrorStrategy]]]) -> ErrorStrategy: + def from_lambda( + fn: Callable[[Optional[Exception]], Tuple[bool, Optional[ErrorStrategy]]] + ) -> ErrorStrategy: """ Convenience method for creating an ErrorStrategy whose ``apply`` method is equivalent to the given lambda. - + The one difference is that the second return value is an ``Optional[ErrorStrategy]`` which can be None to mean "no change", since the lambda cannot reference the strategy's ``self``. """ @@ -97,31 +100,41 @@ def from_lambda(fn: Callable[[Optional[Exception]], Tuple[bool, Optional[ErrorSt class _LambdaErrorStrategy(ErrorStrategy): - def __init__(self, fn: Callable[[Optional[Exception]], Tuple[bool, Optional[ErrorStrategy]]]): + def __init__( + self, fn: Callable[[Optional[Exception]], Tuple[bool, Optional[ErrorStrategy]]] + ): self.__fn = fn - + def apply(self, exception: Optional[Exception]) -> Tuple[bool, ErrorStrategy]: should_raise, maybe_next = self.__fn(exception) return (should_raise, maybe_next or self) + class _MaxAttemptsErrorStrategy(ErrorStrategy): def __init__(self, max_attempts: int, counter: int): self.__max_attempts = max_attempts self.__counter = counter - + def apply(self, exception: Optional[Exception]) -> Tuple[bool, ErrorStrategy]: if self.__counter >= self.__max_attempts: return (ErrorStrategy.FAIL, self) - return (ErrorStrategy.CONTINUE, _MaxAttemptsErrorStrategy(self.__max_attempts, self.__counter + 1)) + return ( + ErrorStrategy.CONTINUE, + _MaxAttemptsErrorStrategy(self.__max_attempts, self.__counter + 1), + ) + class _TimeLimitErrorStrategy(ErrorStrategy): def __init__(self, max_time: float, start_time: float): self.__max_time = max_time self.__start_time = start_time - + def apply(self, exception: Optional[Exception]) -> Tuple[bool, ErrorStrategy]: if self.__start_time == 0: - return (ErrorStrategy.CONTINUE, _TimeLimitErrorStrategy(self.__max_time, time.time())) + return ( + ErrorStrategy.CONTINUE, + _TimeLimitErrorStrategy(self.__max_time, time.time()), + ) if (time.time() - self.__start_time) < self.__max_time: return (ErrorStrategy.CONTINUE, self) return (ErrorStrategy.FAIL, self) diff --git a/ld_eventsource/config/retry_delay_strategy.py b/ld_eventsource/config/retry_delay_strategy.py index 782efcf..89c48a0 100644 --- a/ld_eventsource/config/retry_delay_strategy.py +++ b/ld_eventsource/config/retry_delay_strategy.py @@ -1,6 +1,7 @@ from __future__ import annotations -from random import Random + import time +from random import Random from typing import Callable, Optional, Tuple @@ -33,7 +34,7 @@ def apply(self, base_delay: float) -> Tuple[float, RetryDelayStrategy]: def default( max_delay: Optional[float] = None, backoff_multiplier: float = 2, - jitter_multiplier: Optional[float] = None + jitter_multiplier: Optional[float] = None, ) -> RetryDelayStrategy: """ Provides the default retry delay behavior for :class:`.SSEClient`, which includes @@ -58,11 +59,18 @@ def default( :param jitter_multiplier: a fraction from 0.0 to 1.0 for how much of the delay may be pseudo-randomly subtracted """ - return _DefaultRetryDelayStrategy(max_delay or 0, backoff_multiplier, jitter_multiplier or 0, - 0, _ReusableRandom(time.time())) + return _DefaultRetryDelayStrategy( + max_delay or 0, + backoff_multiplier, + jitter_multiplier or 0, + 0, + _ReusableRandom(time.time()), + ) @staticmethod - def from_lambda(fn: Callable[[float], Tuple[float, Optional[RetryDelayStrategy]]]) -> RetryDelayStrategy: + def from_lambda( + fn: Callable[[float], Tuple[float, Optional[RetryDelayStrategy]]] + ) -> RetryDelayStrategy: """ Convenience method for creating a RetryDelayStrategy whose ``apply`` method is equivalent to the given lambda. @@ -80,7 +88,7 @@ def __init__( backoff_multiplier: float, jitter_multiplier: float, last_base_delay: float, - random: _ReusableRandom + random: _ReusableRandom, ): self.__max_delay = max_delay self.__backoff_multiplier = backoff_multiplier @@ -89,8 +97,11 @@ def __init__( self.__random = random def apply(self, base_delay: float) -> Tuple[float, RetryDelayStrategy]: - next_base_delay = base_delay if self.__last_base_delay == 0 else \ - self.__last_base_delay * self.__backoff_multiplier + next_base_delay = ( + base_delay + if self.__last_base_delay == 0 + else self.__last_base_delay * self.__backoff_multiplier + ) if self.__max_delay > 0 and next_base_delay > self.__max_delay: next_base_delay = self.__max_delay adjusted_delay = next_base_delay @@ -100,25 +111,31 @@ def apply(self, base_delay: float) -> Tuple[float, RetryDelayStrategy]: # To avoid having this object contain mutable state, we create a new Random with the same # state as our previous Random before using it. random = random.clone() - adjusted_delay -= (random.random() * self.__jitter_multiplier * adjusted_delay) + adjusted_delay -= ( + random.random() * self.__jitter_multiplier * adjusted_delay + ) next_strategy = _DefaultRetryDelayStrategy( self.__max_delay, self.__backoff_multiplier, self.__jitter_multiplier, next_base_delay, - random + random, ) return (adjusted_delay, next_strategy) + class _LambdaRetryDelayStrategy(RetryDelayStrategy): - def __init__(self, fn: Callable[[float], Tuple[float, Optional[RetryDelayStrategy]]]): + def __init__( + self, fn: Callable[[float], Tuple[float, Optional[RetryDelayStrategy]]] + ): self.__fn = fn def apply(self, base_delay: float) -> Tuple[float, RetryDelayStrategy]: delay, maybe_next = self.__fn(base_delay) return (delay, maybe_next or self) + class _ReusableRandom: def __init__(self, seed: float): self.__seed = seed diff --git a/ld_eventsource/errors.py b/ld_eventsource/errors.py index 7b96984..bb5733c 100644 --- a/ld_eventsource/errors.py +++ b/ld_eventsource/errors.py @@ -1,4 +1,3 @@ - class HTTPStatusError(Exception): """ This exception indicates that the client was able to connect to the server, but that @@ -8,11 +7,12 @@ class HTTPStatusError(Exception): def __init__(self, status: int): super().__init__("HTTP error %d" % status) self._status = status - + @property def status(self) -> int: return self._status + class HTTPContentTypeError(Exception): """ This exception indicates that the HTTP response did not have the expected content @@ -22,7 +22,7 @@ class HTTPContentTypeError(Exception): def __init__(self, content_type: str): super().__init__("invalid content type \"%s\"" % content_type) self._content_type = content_type - + @property def content_type(self) -> str: return self._content_type diff --git a/ld_eventsource/http.py b/ld_eventsource/http.py index 8965b26..940446a 100644 --- a/ld_eventsource/http.py +++ b/ld_eventsource/http.py @@ -1,5 +1,6 @@ from logging import Logger from typing import Callable, Iterator, Optional, Tuple + from urllib3 import PoolManager from urllib3.exceptions import MaxRetryError from urllib3.util import Retry @@ -13,9 +14,9 @@ class _HttpConnectParams: def __init__( self, url: str, - headers: Optional[dict]=None, - pool: Optional[PoolManager]=None, - urllib3_request_options: Optional[dict]=None + headers: Optional[dict] = None, + pool: Optional[PoolManager] = None, + urllib3_request_options: Optional[dict] = None, ): self.__url = url self.__headers = headers @@ -25,15 +26,15 @@ def __init__( @property def url(self) -> str: return self.__url - + @property def headers(self) -> Optional[dict]: return self.__headers - + @property def pool(self) -> Optional[PoolManager]: return self.__pool - + @property def urllib3_request_options(self) -> Optional[dict]: return self.__urllib3_request_options @@ -45,10 +46,10 @@ def __init__(self, params: _HttpConnectParams, logger: Logger): self.__pool = params.pool or PoolManager() self.__should_close_pool = params.pool is not None self.__logger = logger - + def connect(self, last_event_id: Optional[str]) -> Tuple[Iterator[bytes], Callable]: self.__logger.info("Connecting to stream at %s" % self.__params.url) - + headers = self.__params.headers.copy() if self.__params.headers else {} headers['Cache-Control'] = 'no-cache' headers['Accept'] = 'text/event-stream' @@ -56,7 +57,11 @@ def connect(self, last_event_id: Optional[str]) -> Tuple[Iterator[bytes], Callab if last_event_id: headers['Last-Event-ID'] = last_event_id - request_options = self.__params.urllib3_request_options.copy() if self.__params.urllib3_request_options else {} + request_options = ( + self.__params.urllib3_request_options.copy() + if self.__params.urllib3_request_options + else {} + ) request_options['headers'] = headers try: @@ -64,16 +69,21 @@ def connect(self, last_event_id: Optional[str]) -> Tuple[Iterator[bytes], Callab 'GET', self.__params.url, preload_content=False, - retries=Retry(total=None, read=0, connect=0, status=0, other=0, redirect=3), - **request_options) + retries=Retry( + total=None, read=0, connect=0, status=0, other=0, redirect=3 + ), + **request_options + ) except MaxRetryError as e: reason: Optional[Exception] = e.reason if reason is not None: - raise reason # e.reason is the underlying I/O error + raise reason # e.reason is the underlying I/O error if resp.status >= 400 or resp.status == 204: raise HTTPStatusError(resp.status) content_type = resp.headers.get('Content-Type', None) - if content_type is None or not str(content_type).startswith("text/event-stream"): + if content_type is None or not str(content_type).startswith( + "text/event-stream" + ): raise HTTPContentTypeError(content_type or '') stream = resp.stream(_CHUNK_SIZE) diff --git a/ld_eventsource/reader.py b/ld_eventsource/reader.py index e608a8e..2f9f515 100644 --- a/ld_eventsource/reader.py +++ b/ld_eventsource/reader.py @@ -1,13 +1,14 @@ -from ld_eventsource.actions import Comment, Event - from typing import Callable, Iterable, Optional +from ld_eventsource.actions import Comment, Event + class _BufferedLineReader: """ Helper class that encapsulates the logic for reading UTF-8 stream data as a series of text lines, each of which can be terminated by \n, \r, or \r\n. """ + @staticmethod def lines_from(chunks): """ @@ -49,21 +50,22 @@ def lines_from(chunks): for line in lines: yield line.decode() + class _SSEReader: def __init__( self, lines_source: Iterable[str], last_event_id: Optional[str] = None, - set_retry: Optional[Callable[[int], None]] = None - ): + set_retry: Optional[Callable[[int], None]] = None, + ): self._lines_source = lines_source self._last_event_id = last_event_id self._set_retry = set_retry - + @property def last_event_id(self): return self._last_event_id - + @property def events_and_comments(self): event_type = "" @@ -78,7 +80,7 @@ def events_and_comments(self): "message" if event_type == "" else event_type, event_data, event_id, - self._last_event_id + self._last_event_id, ) event_type = "" event_data = None @@ -95,11 +97,13 @@ def events_and_comments(self): name = line[:colon_pos] if colon_pos < (len(line) - 1) and line[colon_pos + 1] == ' ': colon_pos += 1 - value = line[colon_pos+1:] + value = line[colon_pos + 1:] if name == 'event': event_type = value elif name == 'data': - event_data = value if event_data is None else (event_data + "\n" + value) + event_data = ( + value if event_data is None else (event_data + "\n" + value) + ) elif name == 'id': if value.find("\x00") < 0: event_id = value diff --git a/ld_eventsource/sse_client.py b/ld_eventsource/sse_client.py index 601c18a..1ac774e 100644 --- a/ld_eventsource/sse_client.py +++ b/ld_eventsource/sse_client.py @@ -1,12 +1,12 @@ -from ld_eventsource.actions import * -from ld_eventsource.errors import * -from ld_eventsource.config import * -from ld_eventsource.reader import _BufferedLineReader, _SSEReader - import logging import time from typing import Iterable, Optional, Union +from ld_eventsource.actions import * +from ld_eventsource.config import * +from ld_eventsource.errors import * +from ld_eventsource.reader import _BufferedLineReader, _SSEReader + class SSEClient: """ @@ -15,7 +15,7 @@ class SSEClient: This is a synchronous implementation which blocks the caller's thread when reading events or reconnecting. It can be run on a worker thread. The expected usage is to create an ``SSEClient`` instance, then read from it using the iterator properties :attr:`events` or :attr:`all`. - + By default, ``SSEClient`` uses ``urllib3`` to make HTTP requests to an SSE endpoint. You can customize this behavior using :class:`.ConnectStrategy`. @@ -25,7 +25,7 @@ class SSEClient: dropped or cannot be made; but if a connection is made and returns an invalid response (non-2xx status, 204 status, or invalid content type), it will not retry. This behavior can be customized with ``error_strategy``. The client will automatically follow 3xx redirects. - + For any non-retryable error, if this is the first connection attempt then the constructor will throw an exception (such as :class:`.HTTPStatusError`). Or, if a successful connection was made so the constructor has already returned, but a @@ -42,14 +42,14 @@ class SSEClient: """ def __init__( - self, + self, connect: Union[str, ConnectStrategy], - initial_retry_delay: float=1, - retry_delay_strategy: Optional[RetryDelayStrategy]=None, - retry_delay_reset_threshold: float=60, - error_strategy: Optional[ErrorStrategy]=None, - last_event_id: Optional[str]=None, - logger: Optional[logging.Logger]=None + initial_retry_delay: float = 1, + retry_delay_strategy: Optional[RetryDelayStrategy] = None, + retry_delay_reset_threshold: float = 60, + error_strategy: Optional[ErrorStrategy] = None, + last_event_id: Optional[str] = None, + logger: Optional[logging.Logger] = None, ): """ Creates a client instance. @@ -89,16 +89,18 @@ def __init__( connect = ConnectStrategy.http(connect) elif not isinstance(connect, ConnectStrategy): raise TypeError("request must be either a string or ConnectStrategy") - + self.__base_retry_delay = initial_retry_delay - self.__base_retry_delay_strategy = retry_delay_strategy or RetryDelayStrategy.default() + self.__base_retry_delay_strategy = ( + retry_delay_strategy or RetryDelayStrategy.default() + ) self.__retry_delay_reset_threshold = retry_delay_reset_threshold self.__current_retry_delay_strategy = self.__base_retry_delay_strategy self.__next_retry_delay = 0 self.__base_error_strategy = error_strategy or ErrorStrategy.always_fail() self.__current_error_strategy = self.__base_error_strategy - + self.__last_event_id = last_event_id if logger is None: @@ -106,7 +108,7 @@ def __init__( logger.addHandler(logging.NullHandler()) logger.propagate = False self.__logger = logger - + self.__connection_client = connect.create_client(logger) self.__connection_result: Optional[ConnectionResult] = None self.__connected_time: float = 0 @@ -142,7 +144,7 @@ def close(self): """ self.__closed = True self.interrupt() - + def interrupt(self): """ Stops the stream connection if it is currently active. @@ -162,7 +164,7 @@ def interrupt(self): def all(self) -> Iterable[Action]: """ An iterable series of notifications from the stream. - + Each of these can be any subclass of :class:`.Action`: :class:`.Event`, :class:`.Comment`, :class:`.Start`, or :class:`.Fault`. @@ -179,7 +181,7 @@ def all(self) -> Iterable[Action]: fault = self._try_start(True) # return either a Start action or a Fault action yield Start() if fault is None else fault - + lines = _BufferedLineReader.lines_from(self.__connection_result.stream) reader = _SSEReader(lines, self.__last_event_id, None) error: Optional[Exception] = None @@ -202,11 +204,15 @@ def all(self) -> Iterable[Action]: # We've hit an error, so ask the ErrorStrategy what to do: raise an exception or yield a Fault. self._compute_next_retry_delay() - fail_or_continue, self.__current_error_strategy = self.__current_error_strategy.apply(error) + fail_or_continue, self.__current_error_strategy = ( + self.__current_error_strategy.apply(error) + ) if fail_or_continue == ErrorStrategy.FAIL: if error is None: # If error is None, the stream was ended normally by the server. Just stop iterating. - yield Fault(None) # this is only visible if you're reading from "all" + yield Fault( + None + ) # this is only visible if you're reading from "all" return raise error yield Fault(error) @@ -232,7 +238,7 @@ def next_retry_delay(self) -> float: """ The retry delay that will be used for the next reconnection, in seconds, if the stream has failed or ended. - + This is initially zero, because SSEClient does not compute a retry delay until there is a failure. If you have just received an exception or a :class:`.Fault`, or if you were iterating through events and the events ran out because the stream closed, the value @@ -240,31 +246,40 @@ def next_retry_delay(self) -> float: is computed by applying the configured :class:`.RetryDelayStrategy` to the base retry delay. """ return self.__next_retry_delay - + def _compute_next_retry_delay(self): if self.__retry_delay_reset_threshold > 0 and self.__connected_time != 0: connection_duration = time.time() - self.__connected_time if connection_duration >= self.__retry_delay_reset_threshold: self.__current_retry_delay_strategy = self.__base_retry_delay_strategy - self.__next_retry_delay, self.__current_retry_delay_strategy = \ + self.__next_retry_delay, self.__current_retry_delay_strategy = ( self.__current_retry_delay_strategy.apply(self.__base_retry_delay) + ) def _try_start(self, can_return_fault: bool) -> Optional[Fault]: if self.__connection_result is not None: return None while True: if self.__next_retry_delay > 0: - delay = self.__next_retry_delay if self.__disconnected_time == 0 else \ - self.__next_retry_delay - (time.time() - self.__disconnected_time) + delay = ( + self.__next_retry_delay + if self.__disconnected_time == 0 + else self.__next_retry_delay + - (time.time() - self.__disconnected_time) + ) if delay > 0: self.__logger.info("Will reconnect after delay of %fs" % delay) time.sleep(delay) try: - self.__connection_result = self.__connection_client.connect(self.__last_event_id) + self.__connection_result = self.__connection_client.connect( + self.__last_event_id + ) except Exception as e: self.__disconnected_time = time.time() self._compute_next_retry_delay() - fail_or_continue, self.__current_error_strategy = self.__current_error_strategy.apply(e) + fail_or_continue, self.__current_error_strategy = ( + self.__current_error_strategy.apply(e) + ) if fail_or_continue == ErrorStrategy.FAIL: raise e if can_return_fault: @@ -287,7 +302,7 @@ def last_event_id(self) -> Optional[str]: depends on the server; it may ignore this value. """ return self.__last_event_id - + def __enter__(self): return self diff --git a/ld_eventsource/testing/helpers.py b/ld_eventsource/testing/helpers.py index 0b18ed5..5647493 100644 --- a/ld_eventsource/testing/helpers.py +++ b/ld_eventsource/testing/helpers.py @@ -1,26 +1,35 @@ from __future__ import annotations +from logging import Logger +from typing import Iterable, Iterator, List, Optional + from ld_eventsource import * from ld_eventsource.config import * from ld_eventsource.errors import * - from ld_eventsource.testing.http_util import * -from logging import Logger -from typing import Iterable, Iterator, List, Optional - def make_stream() -> ChunkedResponse: - return ChunkedResponse({ 'Content-Type': 'text/event-stream' }) + return ChunkedResponse({'Content-Type': 'text/event-stream'}) + def retry_for_status(status: int) -> ErrorStrategy: - return ErrorStrategy.from_lambda(lambda error: \ - (ErrorStrategy.CONTINUE if isinstance(error, HTTPStatusError) and error.status == status \ - else ErrorStrategy.FAIL, None)) + return ErrorStrategy.from_lambda( + lambda error: ( + ( + ErrorStrategy.CONTINUE + if isinstance(error, HTTPStatusError) and error.status == status + else ErrorStrategy.FAIL + ), + None, + ) + ) + def no_delay() -> RetryDelayStrategy: return RetryDelayStrategy.from_lambda(lambda _: (0, None)) + class MockConnectStrategy(ConnectStrategy): def __init__(self, *request_handlers: MockConnectionHandler): self.__handlers = list(request_handlers) @@ -28,6 +37,7 @@ def __init__(self, *request_handlers: MockConnectionHandler): def create_client(self, logger: Logger) -> ConnectionClient: return MockConnectionClient(self.__handlers) + class MockConnectionClient(ConnectionClient): def __init__(self, handlers: List[MockConnectionHandler]): self.__handlers = handlers @@ -39,9 +49,13 @@ def connect(self, last_event_id: Optional[str]) -> ConnectionResult: self.__request_count += 1 return handler.apply() + class MockConnectionHandler: def apply(self) -> ConnectionResult: - raise NotImplementedError("MockConnectionHandler base class cannot be used by itself") + raise NotImplementedError( + "MockConnectionHandler base class cannot be used by itself" + ) + class RejectConnection(MockConnectionHandler): def __init__(self, error: Exception): @@ -50,6 +64,7 @@ def __init__(self, error: Exception): def apply(self) -> ConnectionResult: raise self.__error + class RespondWithStream(MockConnectionHandler): def __init__(self, stream: Iterable[bytes]): self.__stream = stream @@ -57,10 +72,12 @@ def __init__(self, stream: Iterable[bytes]): def apply(self) -> ConnectionResult: return ConnectionResult(stream=self.__stream.__iter__(), closer=None) + class RespondWithData(RespondWithStream): def __init__(self, data: str): super().__init__([bytes(data, 'utf-8')]) + class ExpectNoMoreRequests(MockConnectionHandler): def apply(self) -> ConnectionResult: assert False, "SSEClient should not have made another request" diff --git a/ld_eventsource/testing/http_util.py b/ld_eventsource/testing/http_util.py index a163a49..abcae4c 100644 --- a/ld_eventsource/testing/http_util.py +++ b/ld_eventsource/testing/http_util.py @@ -1,18 +1,20 @@ import json +import queue import socket import ssl -from threading import Thread import time -import queue -from http.server import HTTPServer, BaseHTTPRequestHandler +from http.server import BaseHTTPRequestHandler, HTTPServer +from threading import Thread + def get_available_port(): - s = socket.socket(socket.AF_INET, type = socket.SOCK_STREAM) + s = socket.socket(socket.AF_INET, type=socket.SOCK_STREAM) s.bind(('localhost', 0)) _, port = s.getsockname() s.close() return port + def poll_until_started(port): deadline = time.time() + 1 while time.time() < deadline: @@ -27,18 +29,21 @@ def poll_until_started(port): time.sleep(0.05) raise Exception("test server on port %d was not reachable" % port) + def start_server(): sw = MockServerWrapper(get_available_port(), False) sw.start() poll_until_started(sw.port) return sw + def start_secure_server(): sw = MockServerWrapper(get_available_port(), True) sw.start() poll_until_started(sw.port) return sw + class MockServerWrapper(Thread): def __init__(self, port, secure): Thread.__init__(self) @@ -48,9 +53,9 @@ def __init__(self, port, secure): if secure: self.server.socket = ssl.wrap_socket( self.server.socket, - certfile='./ld_eventsource/testing/selfsigned.pem', # this is a pre-generated self-signed cert that is valid for 100 years + certfile='./ld_eventsource/testing/selfsigned.pem', # this is a pre-generated self-signed cert that is valid for 100 years keyfile='./ld_eventsource/testing/selfsigned.key', - server_side=True + server_side=True, ) self.server.server_wrapper = self self.matchers = {} @@ -94,6 +99,7 @@ def __enter__(self): def __exit__(self, type, value, traceback): self.close() + class MockServerRequestHandler(BaseHTTPRequestHandler): def do_CONNECT(self): self._do_request() @@ -113,6 +119,7 @@ def _do_request(self): else: self.send_error(404) + class MockServerRequest: def __init__(self, request): self.method = request.command @@ -127,8 +134,9 @@ def __init__(self, request): def __str__(self): return "%s %s" % (self.method, self.path) + class BasicResponse: - def __init__(self, status, body = None, headers = None): + def __init__(self, status, body=None, headers=None): self.status = status self.body = body self.headers = headers or {} @@ -145,14 +153,16 @@ def write(self, request): if self.body: request.wfile.write(self.body.encode('UTF-8')) + class JsonResponse(BasicResponse): - def __init__(self, data, headers = None): + def __init__(self, data, headers=None): h = headers or {} - h.update({ 'Content-Type': 'application/json' }) + h.update({'Content-Type': 'application/json'}) BasicResponse.__init__(self, 200, json.dumps(data or {}), h) + class ChunkedResponse: - def __init__(self, headers = None): + def __init__(self, headers=None): self.queue = queue.Queue() self.headers = headers or {} @@ -177,7 +187,9 @@ def write(self, request): request.wfile.flush() break else: - request.wfile.write(('%x\r\n%s\r\n' % (len(chunk), chunk)).encode('UTF-8')) + request.wfile.write( + ('%x\r\n%s\r\n' % (len(chunk), chunk)).encode('UTF-8') + ) request.wfile.flush() def __enter__(self): @@ -186,10 +198,12 @@ def __enter__(self): def __exit__(self, type, value, traceback): self.close() + class CauseNetworkError: def write(self, request): raise Exception('intentional error') + class SequentialHandler: def __init__(self, *argv): self.handlers = argv diff --git a/ld_eventsource/testing/test_error_strategy.py b/ld_eventsource/testing/test_error_strategy.py index a73885c..ed21b2a 100644 --- a/ld_eventsource/testing/test_error_strategy.py +++ b/ld_eventsource/testing/test_error_strategy.py @@ -1,7 +1,6 @@ -from ld_eventsource.config import * - import time +from ld_eventsource.config import * err = Exception("sorry") @@ -13,6 +12,7 @@ def test_always_raise(): assert should_raise is True strategy = next_strategy or strategy + def test_always_continue(): strategy = ErrorStrategy.always_continue() for i in range(100): diff --git a/ld_eventsource/testing/test_http_connect_strategy.py b/ld_eventsource/testing/test_http_connect_strategy.py index 7940ea2..dd50c6f 100644 --- a/ld_eventsource/testing/test_http_connect_strategy.py +++ b/ld_eventsource/testing/test_http_connect_strategy.py @@ -1,18 +1,20 @@ -from ld_eventsource.config.connect_strategy import * +import logging + +from urllib3.exceptions import ProtocolError +from ld_eventsource.config.connect_strategy import * from ld_eventsource.testing.helpers import * from ld_eventsource.testing.http_util import * -import logging -from urllib3.exceptions import ProtocolError - # Tests of the basic client/request configuration methods and HTTP functionality in # ConnectStrategy.http(), using an embedded HTTP server as a target, but without using # SSEClient. + def logger(): return logging.getLogger("test") + def test_http_request_gets_chunked_data(): with start_server() as server: with make_stream() as stream: @@ -24,6 +26,7 @@ def test_http_request_gets_chunked_data(): stream.push('world') assert next(cxn.stream) == b'world' + def test_http_request_default_headers(): with start_server() as server: with make_stream() as stream: @@ -35,6 +38,7 @@ def test_http_request_default_headers(): assert r.headers['Cache-Control'] == 'no-cache' assert r.headers.get('Last-Event-Id') is None + def test_http_request_custom_default_headers(): with start_server() as server: with make_stream() as stream: @@ -47,6 +51,7 @@ def test_http_request_custom_default_headers(): assert r.headers['Cache-Control'] == 'no-cache' assert r.headers['name1'] == 'value1' + def test_http_request_last_event_id_header(): with start_server() as server: with make_stream() as stream: @@ -57,6 +62,7 @@ def test_http_request_last_event_id_header(): r = server.await_request() assert r.headers['Last-Event-Id'] == 'id123' + def test_http_status_error(): with start_server() as server: server.for_path('/', BasicResponse(400)) @@ -67,9 +73,10 @@ def test_http_status_error(): except HTTPStatusError as e: assert e.status == 400 + def test_http_content_type_error(): with start_server() as server: - with ChunkedResponse({ 'Content-Type': 'text/plain' }) as stream: + with ChunkedResponse({'Content-Type': 'text/plain'}) as stream: server.for_path('/', stream) try: with ConnectStrategy.http(server.uri).create_client(logger()) as client: @@ -78,36 +85,44 @@ def test_http_content_type_error(): except HTTPContentTypeError as e: assert e.content_type == "text/plain" + def test_http_io_error(): with start_server() as server: - server.for_path('/', CauseNetworkError()) - try: - with ConnectStrategy.http(server.uri).create_client(logger()) as client: - client.connect(None) - raise Exception("expected exception, did not get one") - except ProtocolError as e: - pass + server.for_path('/', CauseNetworkError()) + try: + with ConnectStrategy.http(server.uri).create_client(logger()) as client: + client.connect(None) + raise Exception("expected exception, did not get one") + except ProtocolError as e: + pass + def test_auto_redirect_301(): with start_server() as server: with make_stream() as stream: - server.for_path('/', BasicResponse(301, None, {'Location': server.uri + '/real-stream'})) + server.for_path( + '/', BasicResponse(301, None, {'Location': server.uri + '/real-stream'}) + ) server.for_path('/real-stream', stream) with ConnectStrategy.http(server.uri).create_client(logger()) as client: client.connect(None) server.await_request() server.await_request() + def test_auto_redirect_307(): with start_server() as server: with make_stream() as stream: - server.for_path('/', BasicResponse(307, None, {'Location': server.uri + '/real-stream'})) + server.for_path( + '/', BasicResponse(307, None, {'Location': server.uri + '/real-stream'}) + ) server.for_path('/real-stream', stream) with ConnectStrategy.http(server.uri).create_client(logger()) as client: client.connect(None) server.await_request() server.await_request() + def test_sse_client_with_http_connect_strategy(): # Just a basic smoke test to prove that SSEClient interacts with the ConnectStrategy correctly. with start_server() as server: diff --git a/ld_eventsource/testing/test_http_connect_strategy_with_sse_client.py b/ld_eventsource/testing/test_http_connect_strategy_with_sse_client.py index b09d041..5b6bbdf 100644 --- a/ld_eventsource/testing/test_http_connect_strategy_with_sse_client.py +++ b/ld_eventsource/testing/test_http_connect_strategy_with_sse_client.py @@ -1,11 +1,11 @@ from ld_eventsource import * from ld_eventsource.config import * - from ld_eventsource.testing.helpers import * from ld_eventsource.testing.http_util import * # Tests of basic SSEClient behavior using real HTTP requests. + def test_sse_client_reads_events(): with start_server() as server: with make_stream() as stream: @@ -21,53 +21,58 @@ def test_sse_client_reads_events(): assert event2.event == 'b' assert event2.data == 'data2' + def test_sse_client_sends_initial_last_event_id(): with start_server() as server: with make_stream() as stream: server.for_path('/', stream) - with SSEClient(connect=ConnectStrategy.http(server.uri), last_event_id="id123") as client: + with SSEClient( + connect=ConnectStrategy.http(server.uri), last_event_id="id123" + ) as client: client.start() r = server.await_request() assert r.headers['Last-Event-Id'] == 'id123' + def test_sse_client_reconnects_after_socket_closed(): with start_server() as server: - with make_stream() as stream1: - with make_stream() as stream2: - server.for_path('/', SequentialHandler(stream1, stream2)) - stream1.push("event: a\ndata: data1\n\n") - stream2.push("event: b\ndata: data2\n\n") - with SSEClient( - connect=ConnectStrategy.http(server.uri), - error_strategy=ErrorStrategy.always_continue(), - initial_retry_delay=0 - ) as client: - client.start() - event1 = next(client.events) - assert event1.event == 'a' - assert event1.data == 'data1' - stream1.close() - event2 = next(client.events) - assert event2.event == 'b' - assert event2.data == 'data2' + with make_stream() as stream1: + with make_stream() as stream2: + server.for_path('/', SequentialHandler(stream1, stream2)) + stream1.push("event: a\ndata: data1\n\n") + stream2.push("event: b\ndata: data2\n\n") + with SSEClient( + connect=ConnectStrategy.http(server.uri), + error_strategy=ErrorStrategy.always_continue(), + initial_retry_delay=0, + ) as client: + client.start() + event1 = next(client.events) + assert event1.event == 'a' + assert event1.data == 'data1' + stream1.close() + event2 = next(client.events) + assert event2.event == 'b' + assert event2.data == 'data2' + def test_sse_client_sends_last_event_id_on_reconnect(): with start_server() as server: - with make_stream() as stream1: - with make_stream() as stream2: - server.for_path('/', SequentialHandler(stream1, stream2)) - stream1.push("event: a\ndata: data1\nid: id123\n\n") - stream2.push("event: b\ndata: data2\n\n") - with SSEClient( - connect=ConnectStrategy.http(server.uri), - error_strategy=ErrorStrategy.always_continue(), - initial_retry_delay=0 - ) as client: - client.start() - next(client.events) - stream1.close() - next(client.events) - r1 = server.await_request() - assert r1.headers.get('Last-Event-Id') is None - r2 = server.await_request() - assert r2.headers['Last-Event-Id'] == 'id123' + with make_stream() as stream1: + with make_stream() as stream2: + server.for_path('/', SequentialHandler(stream1, stream2)) + stream1.push("event: a\ndata: data1\nid: id123\n\n") + stream2.push("event: b\ndata: data2\n\n") + with SSEClient( + connect=ConnectStrategy.http(server.uri), + error_strategy=ErrorStrategy.always_continue(), + initial_retry_delay=0, + ) as client: + client.start() + next(client.events) + stream1.close() + next(client.events) + r1 = server.await_request() + assert r1.headers.get('Last-Event-Id') is None + r2 = server.await_request() + assert r2.headers['Last-Event-Id'] == 'id123' diff --git a/ld_eventsource/testing/test_reader.py b/ld_eventsource/testing/test_reader.py index 4e00434..2802d37 100644 --- a/ld_eventsource/testing/test_reader.py +++ b/ld_eventsource/testing/test_reader.py @@ -1,48 +1,46 @@ +import pytest + from ld_eventsource.actions import Comment, Event from ld_eventsource.reader import _BufferedLineReader, _SSEReader -import pytest - -class TestBufferedLineReader: - @pytest.fixture(params = ["\r", "\n", "\r\n"]) +class TestBufferedLineReader: + @pytest.fixture(params=["\r", "\n", "\r\n"]) def terminator(self, request): return request.param - - @pytest.fixture(params = [ - [ - [ "first line*", "second line*", "3rd line*" ], - [ "first line", "second line", "3rd line"] - ], - [ - [ "*", "second line*", "3rd line*" ], - [ "", "second line", "3rd line"] - ], - [ - [ "first line*", "*", "3rd line*" ], - [ "first line", "", "3rd line"] - ], - [ - [ "first line*", "*", "*", "*", "3rd line*" ], - [ "first line", "", "", "", "3rd line" ] - ], - [ - [ "first line*second line*third", " line*fourth line*"], - [ "first line", "second line", "third line", "fourth line" ] - ], - ]) + + @pytest.fixture( + params=[ + [ + ["first line*", "second line*", "3rd line*"], + ["first line", "second line", "3rd line"], + ], + [["*", "second line*", "3rd line*"], ["", "second line", "3rd line"]], + [["first line*", "*", "3rd line*"], ["first line", "", "3rd line"]], + [ + ["first line*", "*", "*", "*", "3rd line*"], + ["first line", "", "", "", "3rd line"], + ], + [ + ["first line*second line*third", " line*fourth line*"], + ["first line", "second line", "third line", "fourth line"], + ], + ] + ) def inputs_outputs(self, terminator, request): inputs = list(s.replace("*", terminator).encode() for s in request.param[0]) return [inputs, request.param[1]] def test_parsing(self, inputs_outputs): - assert list(_BufferedLineReader.lines_from(inputs_outputs[0])) == inputs_outputs[1] + assert ( + list(_BufferedLineReader.lines_from(inputs_outputs[0])) == inputs_outputs[1] + ) def test_mixed_terminators(self): chunks = [ b"first line\nsecond line\r\nthird line\r", b"\nfourth line\r", - b"\r\nlast\r\n" + b"\r\nlast\r\n", ] expected = [ "first line", @@ -50,7 +48,7 @@ def test_mixed_terminators(self): "third line", "fourth line", "", - "last" + "last", ] assert list(_BufferedLineReader.lines_from(chunks)) == expected @@ -59,79 +57,52 @@ class TestSSEReader: def expect_output(self, lines, expected): output = list(_SSEReader(lines).events_and_comments) assert output == expected - + def test_parses_event_with_all_fields(self): - lines = [ - "event: abc", - "data: def", - "id: 1", - "" - ] + lines = ["event: abc", "data: def", "id: 1", ""] expected_event = Event("abc", "def", "1", "1") - self.expect_output(lines, [ expected_event ]) + self.expect_output(lines, [expected_event]) def test_parses_event_with_only_data(self): - lines = [ - "data: def", - "" - ] + lines = ["data: def", ""] expected_event = Event("message", "def") - self.expect_output(lines, [ expected_event ]) + self.expect_output(lines, [expected_event]) def test_parses_event_with_multi_line_data(self): - lines = [ - "data: def", - "data: ghi", - "" - ] + lines = ["data: def", "data: ghi", ""] expected_event = Event("message", "def\nghi") - self.expect_output(lines, [ expected_event ]) + self.expect_output(lines, [expected_event]) def test_parses_event_with_empty_data(self): - lines = [ - "data:", - "" - ] + lines = ["data:", ""] expected_event = Event("message", "") - self.expect_output(lines, [ expected_event ]) + self.expect_output(lines, [expected_event]) def test_parses_comment(self): - lines = [ - ":comment", - "data: abc", - "" - ] + lines = [":comment", "data: abc", ""] expected_comment = Comment("comment") expected_event = Event("message", "abc") - self.expect_output(lines, [ expected_comment, expected_event ]) + self.expect_output(lines, [expected_comment, expected_event]) def test_parses_multiple_events(self): - lines = [ - "event: abc", - "data: def", - "", - "data: ghi", - "" - ] + lines = ["event: abc", "data: def", "", "data: ghi", ""] event1 = Event("abc", "def") event2 = Event("message", "ghi") - self.expect_output(lines, [ event1, event2 ]) + self.expect_output(lines, [event1, event2]) def test_parses_retry_interval(self): got_retry = None + def store_retry(value): nonlocal got_retry got_retry = value - lines = [ - "retry: 1000" - ] + + lines = ["retry: 1000"] list(_SSEReader(lines, None, store_retry).events_and_comments) assert got_retry == 1000 def test_ignores_retry_interval_if_no_callback_given(self): - lines = [ - "retry: 1000" - ] + lines = ["retry: 1000"] list(_SSEReader(lines, None, None).events_and_comments) def test_remembers_last_event_id(self): @@ -145,13 +116,13 @@ def test_remembers_last_event_id(self): "", "data: fourth", "id:", - "" + "", ] expected = [ Event("message", "first", None, "a"), Event("message", "second", "b", "b"), Event("message", "third", None, "b"), - Event("message", "fourth", "", "") + Event("message", "fourth", "", ""), ] output = list(_SSEReader(lines, last_event_id="a").events_and_comments) assert output == expected diff --git a/ld_eventsource/testing/test_retry_delay_strategy.py b/ld_eventsource/testing/test_retry_delay_strategy.py index fb5a407..d455cc1 100644 --- a/ld_eventsource/testing/test_retry_delay_strategy.py +++ b/ld_eventsource/testing/test_retry_delay_strategy.py @@ -1,11 +1,13 @@ -from ld_eventsource.config import * - from typing import Optional, Tuple +from ld_eventsource.config import * + def test_backoff_with_no_jitter_and_no_max(): base = 4 - strategy = RetryDelayStrategy.default(max_delay=None, backoff_multiplier=2, jitter_multiplier=None) + strategy = RetryDelayStrategy.default( + max_delay=None, backoff_multiplier=2, jitter_multiplier=None + ) delay, next1 = strategy.apply(base) assert delay == base @@ -23,7 +25,9 @@ def test_backoff_with_no_jitter_and_no_max(): def test_backoff_with_no_jitter_and_max(): base = 4 max = base * 4 + 3 - strategy = RetryDelayStrategy.default(max_delay=max, backoff_multiplier=2, jitter_multiplier=None) + strategy = RetryDelayStrategy.default( + max_delay=max, backoff_multiplier=2, jitter_multiplier=None + ) delay, next1 = strategy.apply(base) assert delay == base @@ -40,7 +44,9 @@ def test_backoff_with_no_jitter_and_max(): def test_no_backoff_and_no_jitter(): base = 4 - strategy = RetryDelayStrategy.default(max_delay=None, backoff_multiplier=1, jitter_multiplier=None) + strategy = RetryDelayStrategy.default( + max_delay=None, backoff_multiplier=1, jitter_multiplier=None + ) delay, next1 = strategy.apply(base) assert delay == base @@ -57,7 +63,9 @@ def test_backoff_with_jitter(): backoff = 2 max = base * backoff * backoff + 3 jitter = 0.25 - strategy = RetryDelayStrategy.default(max_delay=max, backoff_multiplier=backoff, jitter_multiplier=jitter) + strategy = RetryDelayStrategy.default( + max_delay=max, backoff_multiplier=backoff, jitter_multiplier=jitter + ) _, next1 = verify_jitter(strategy, base, base, jitter) _, next2 = verify_jitter(next1, base, base * backoff, jitter) @@ -73,8 +81,9 @@ def zero_base_delay_always_produces_zero(): r = r -def verify_jitter(strategy: RetryDelayStrategy, base: float, base_with_backoff: float, jitter: float) \ - -> Tuple[float, Optional[RetryDelayStrategy]]: +def verify_jitter( + strategy: RetryDelayStrategy, base: float, base_with_backoff: float, jitter: float +) -> Tuple[float, Optional[RetryDelayStrategy]]: # We can't 100% prove that it's using the expected jitter ratio, since the result # is pseudo-random, but we can at least prove that repeated computations don't # fall outside the expected range and aren't all equal. diff --git a/ld_eventsource/testing/test_sse_client_basic.py b/ld_eventsource/testing/test_sse_client_basic.py index e85a25a..8ed0a1c 100644 --- a/ld_eventsource/testing/test_sse_client_basic.py +++ b/ld_eventsource/testing/test_sse_client_basic.py @@ -1,20 +1,22 @@ +import pytest + from ld_eventsource import * from ld_eventsource.actions import * from ld_eventsource.config import * - from ld_eventsource.testing.helpers import * -import pytest - # Tests for SSEClient's basic properties and parsing behavior. These tests do not use real HTTP # requests; instead, they use a ConnectStrategy that provides a preconfigured input stream. HTTP # functionality is tested separately in test_http_connect_strategy.py and # test_http_connect_strategy_with_sse_client.py. + @pytest.mark.parametrize('explicitly_start', [False, True]) def test_receives_events(explicitly_start: bool): mock = MockConnectStrategy( - RespondWithData("event: event1\ndata: data1\n\n:whatever\nevent: event2\ndata: data2\n\n") + RespondWithData( + "event: event1\ndata: data1\n\n:whatever\nevent: event2\ndata: data2\n\n" + ) ) with SSEClient(connect=mock) as client: if explicitly_start: @@ -30,10 +32,9 @@ def test_receives_events(explicitly_start: bool): assert event2.event == 'event2' assert event2.data == 'data2' + def test_events_returns_eof_when_stream_ends(): - mock = MockConnectStrategy( - RespondWithData("event: event1\ndata: data1\n\n") - ) + mock = MockConnectStrategy(RespondWithData("event: event1\ndata: data1\n\n")) with SSEClient(connect=mock) as client: events = client.events @@ -44,9 +45,12 @@ def test_events_returns_eof_when_stream_ends(): event2 = next(events, "done") assert event2 == "done" + def test_receives_all(): mock = MockConnectStrategy( - RespondWithData("event: event1\ndata: data1\n\n:whatever\nevent: event2\ndata: data2\n\n") + RespondWithData( + "event: event1\ndata: data1\n\n:whatever\nevent: event2\ndata: data2\n\n" + ) ) with SSEClient(connect=mock) as client: all = client.all @@ -68,10 +72,9 @@ def test_receives_all(): assert item4.event == 'event2' assert item4.data == 'data2' + def test_all_returns_fault_and_eof_when_stream_ends(): - mock = MockConnectStrategy( - RespondWithData("event: event1\ndata: data1\n\n") - ) + mock = MockConnectStrategy(RespondWithData("event: event1\ndata: data1\n\n")) with SSEClient(connect=mock) as client: all = client.all diff --git a/ld_eventsource/testing/test_sse_client_retry.py b/ld_eventsource/testing/test_sse_client_retry.py index 91611b6..39783de 100644 --- a/ld_eventsource/testing/test_sse_client_retry.py +++ b/ld_eventsource/testing/test_sse_client_retry.py @@ -1,7 +1,6 @@ from ld_eventsource import * from ld_eventsource.actions import * from ld_eventsource.config import * - from ld_eventsource.testing.helpers import * @@ -14,7 +13,7 @@ def test_retry_during_initial_connect_succeeds(): with SSEClient( connect=mock, retry_delay_strategy=no_delay(), - error_strategy=retry_for_status(503) + error_strategy=retry_for_status(503), ) as client: client.start() @@ -22,6 +21,7 @@ def test_retry_during_initial_connect_succeeds(): event1 = next(events) assert event1.data == 'data1' + def test_retry_during_initial_connect_succeeds_then_fails(): mock = MockConnectStrategy( RejectConnection(HTTPStatusError(503)), @@ -32,13 +32,14 @@ def test_retry_during_initial_connect_succeeds_then_fails(): with SSEClient( connect=mock, retry_delay_strategy=no_delay(), - error_strategy=retry_for_status(503) + error_strategy=retry_for_status(503), ) as client: client.start() raise Exception("expected exception, did not get one") except HTTPStatusError as e: assert e.status == 400 + def test_events_iterator_continues_after_retry(): mock = MockConnectStrategy( RespondWithData("data: data1\n\n"), @@ -48,7 +49,7 @@ def test_events_iterator_continues_after_retry(): with SSEClient( connect=mock, error_strategy=ErrorStrategy.always_continue(), - retry_delay_strategy=no_delay() + retry_delay_strategy=no_delay(), ) as client: events = client.events @@ -58,6 +59,7 @@ def test_events_iterator_continues_after_retry(): event2 = next(events) assert event2.data == 'data2' + def test_all_iterator_continues_after_retry(): initial_delay = 0.005 mock = MockConnectStrategy( @@ -70,7 +72,7 @@ def test_all_iterator_continues_after_retry(): connect=mock, error_strategy=ErrorStrategy.always_continue(), initial_retry_delay=initial_delay, - retry_delay_strategy=RetryDelayStrategy.default(jitter_multiplier=None) + retry_delay_strategy=RetryDelayStrategy.default(jitter_multiplier=None), ) as client: all = client.all @@ -98,6 +100,7 @@ def test_all_iterator_continues_after_retry(): assert item6.error is None assert client.next_retry_delay == initial_delay * 2 + def test_can_interrupt_and_restart_stream(): initial_delay = 0.005 mock = MockConnectStrategy( @@ -109,7 +112,7 @@ def test_can_interrupt_and_restart_stream(): connect=mock, error_strategy=ErrorStrategy.always_continue(), initial_retry_delay=initial_delay, - retry_delay_strategy=RetryDelayStrategy.default(jitter_multiplier=None) + retry_delay_strategy=RetryDelayStrategy.default(jitter_multiplier=None), ) as client: all = client.all diff --git a/pyproject.toml b/pyproject.toml index eed4dea..c4eeb53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,8 @@ urllib3 = ">=1.26.0,<3" mock = ">=2.0.0" pytest = ">=2.8" mypy = "^1.4.0" +pycodestyle = "^2.12.1" +isort = "^5.13.2" [tool.poetry.group.contract-tests] diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..1fb1827 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[pycodestyle] +ignore = E501,W503