From 55e47f66a11b026cd325313ccfa8181f8dbc4745 Mon Sep 17 00:00:00 2001 From: Corin Chaplin Date: Wed, 19 Feb 2025 09:56:52 +0000 Subject: [PATCH] PoC None Projector --- .../event/projection/projector_with_none.py | 100 +++++++++++ .../projection/test_projector_with_none.py | 156 ++++++++++++++++++ 2 files changed, 256 insertions(+) create mode 100644 src/logicblocks/event/projection/projector_with_none.py create mode 100644 tests/unit/logicblocks/event/projection/test_projector_with_none.py diff --git a/src/logicblocks/event/projection/projector_with_none.py b/src/logicblocks/event/projection/projector_with_none.py new file mode 100644 index 00000000..94033b6f --- /dev/null +++ b/src/logicblocks/event/projection/projector_with_none.py @@ -0,0 +1,100 @@ +import inspect +from abc import ABC +from collections.abc import Callable, Mapping +from enum import StrEnum, auto +from inspect import Parameter +from types import NoneType, UnionType +from typing import Any, Union, get_args, get_origin + +from logicblocks.event.types import ( + EventSourceIdentifier, + StoredEvent, +) + +from .projector import Projector + + +class StateAnnotationType(StrEnum): + NONE = auto() + STATE = auto() + NONE_OR_STATE = auto() + NO_ANNOTATION = auto() + + +def get_param_annotation(param: Parameter | None) -> Any | Parameter.empty: + if param is None: + return Parameter.empty + + return param.annotation + + +def get_state_annotation_type( + annotation: Any | Parameter.empty, +) -> StateAnnotationType: + if annotation == Parameter.empty: + return StateAnnotationType.NO_ANNOTATION + + if annotation is None: + return StateAnnotationType.NONE + + if ( + get_origin(annotation) is Union or get_origin(annotation) is UnionType + ) and NoneType in get_args(annotation): + return StateAnnotationType.NONE_OR_STATE + + return StateAnnotationType.STATE + + +def extract_state_annotations_from_handler( + handler: Callable[..., Any], +) -> tuple[StateAnnotationType, Any]: + handler_sig = inspect.signature(handler) + state_annotation = get_param_annotation( + handler_sig.parameters.get("state") + ) + state_annotation_type = get_state_annotation_type(state_annotation) + + return state_annotation_type, state_annotation + + +class ProjectorWithNone[ + State, + Identifier: EventSourceIdentifier, + Metadata = Mapping[str, Any], +](Projector[State | None, Identifier, Metadata], ABC): + def initial_state_factory(self) -> State | None: + return None + + def _resolve_handler( + self, event: StoredEvent + ) -> Callable[[State | None, StoredEvent], State | None]: + handler = super()._resolve_handler(event) + state_annotation_type, state_annotation = ( + extract_state_annotations_from_handler(handler) + ) + + def _wrapped_handler( + state: State | None, event: StoredEvent + ) -> State | None: + match state_annotation_type, state: + case StateAnnotationType.NONE_OR_STATE, _: + return handler(state, event) + + case StateAnnotationType.NONE, None: + return handler(state, event) + case StateAnnotationType.NONE, _: + raise ValueError( + f"Initial handler {handler.__name__} should not be passed a state" + ) + + case StateAnnotationType.STATE, None: + raise ValueError( + f"Handler {handler.__name__} was passed None but expected {state_annotation}" + ) + case StateAnnotationType.STATE, _: + return handler(state, event) + + case StateAnnotationType.NO_ANNOTATION, _: + return handler(state, event) + + return _wrapped_handler diff --git a/tests/unit/logicblocks/event/projection/test_projector_with_none.py b/tests/unit/logicblocks/event/projection/test_projector_with_none.py new file mode 100644 index 00000000..fb5549dd --- /dev/null +++ b/tests/unit/logicblocks/event/projection/test_projector_with_none.py @@ -0,0 +1,156 @@ +from inspect import Parameter +from typing import Any, Mapping, Optional, Union + +import pytest + +from logicblocks.event.projection.projector_with_none import ( + ProjectorWithNone, + StateAnnotationType, + extract_state_annotations_from_handler, +) +from logicblocks.event.sources import InMemoryEventSource +from logicblocks.event.testing import StoredEventBuilder +from logicblocks.event.types import LogIdentifier, StoredEvent + + +class TestExtractStateAnnotationsFromHandler: + def test_no_arg(self): + def handler(event): + pass + + assert extract_state_annotations_from_handler(handler) == ( + StateAnnotationType.NO_ANNOTATION, + Parameter.empty, + ) + + def test_no_annotation(self): + def handler(state, event): + pass + + assert extract_state_annotations_from_handler(handler) == ( + StateAnnotationType.NO_ANNOTATION, + Parameter.empty, + ) + + def test_with_lambda(self): + assert extract_state_annotations_from_handler( + lambda state, event: ... + ) == ( + StateAnnotationType.NO_ANNOTATION, + Parameter.empty, + ) + + def test_none(self): + def handler(state: None, event): + pass + + assert extract_state_annotations_from_handler(handler) == ( + StateAnnotationType.NONE, + None, + ) + + def test_optional(self): + def handler(state: Optional[str], event): + pass + + assert extract_state_annotations_from_handler(handler) == ( + StateAnnotationType.NONE_OR_STATE, + Optional[str], + ) + + def test_union_none(self): + def handler(state: Union[str, None], event): + pass + + assert extract_state_annotations_from_handler(handler) == ( + StateAnnotationType.NONE_OR_STATE, + Union[str, None], + ) + + def test_union_none_shorthand(self): + def handler(state: str | None, event): + pass + + assert extract_state_annotations_from_handler(handler) == ( + StateAnnotationType.NONE_OR_STATE, + str | None, + ) + + def test_other_annotation(self): + def handler(state: str, event): + pass + + assert extract_state_annotations_from_handler(handler) == ( + StateAnnotationType.STATE, + str, + ) + + +class TestProjectorWithNone: + class MyProjector(ProjectorWithNone[Mapping[str, str], LogIdentifier]): + def initial_metadata_factory(self) -> Mapping[str, Any]: + return {} + + def id_factory( + self, state: Mapping[str, str] | None, source: LogIdentifier + ) -> str: + return str(source) + + @staticmethod + def test_first_event( + state: None, event: StoredEvent + ) -> Mapping[str, str]: + return {} + + @staticmethod + def test_later_event( + state: Mapping[str, str], event: StoredEvent + ) -> Mapping[str, str]: + return state + + @staticmethod + def test_any_position_event( + state: Mapping[str, str] | None, event: StoredEvent + ) -> Mapping[str, str]: + return {} + + async def test_events_work_in_correct_position(self): + events = [ + StoredEventBuilder(name="test_first_event", position=0).build(), + StoredEventBuilder(name="test_later_event", position=1).build(), + ] + source = InMemoryEventSource(events, LogIdentifier()) + + await self.MyProjector().project(source=source) + + async def test_later_event_raises_exception_is_used_first(self): + events = [ + StoredEventBuilder(name="test_later_event", position=0).build(), + ] + source = InMemoryEventSource(events, LogIdentifier()) + + with pytest.raises(ValueError): + await self.MyProjector().project(source=source) + + async def test_first_event_raises_exception_is_used_later(self): + events = [ + StoredEventBuilder(name="test_later_event", position=0).build(), + StoredEventBuilder(name="test_first_event", position=1).build(), + ] + source = InMemoryEventSource(events, LogIdentifier()) + + with pytest.raises(ValueError): + await self.MyProjector().project(source=source) + + async def test_any_events_work_in_any_position(self): + events = [ + StoredEventBuilder( + name="test_any_position_event", position=0 + ).build(), + StoredEventBuilder( + name="test_any_position_event", position=1 + ).build(), + ] + source = InMemoryEventSource(events, LogIdentifier()) + + await self.MyProjector().project(source=source)