Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions src/logicblocks/event/projection/projector_with_none.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import inspect
from abc import ABC
from collections.abc import Callable, Mapping
from enum import StrEnum, auto
from inspect import Parameter
from types import NoneType, UnionType
from typing import Any, Union, get_args, get_origin

from logicblocks.event.types import (
EventSourceIdentifier,
StoredEvent,
)

from .projector import Projector


class StateAnnotationType(StrEnum):
NONE = auto()
STATE = auto()
NONE_OR_STATE = auto()
NO_ANNOTATION = auto()


def get_param_annotation(param: Parameter | None) -> Any | Parameter.empty:
if param is None:
return Parameter.empty

return param.annotation


def get_state_annotation_type(
annotation: Any | Parameter.empty,
) -> StateAnnotationType:
if annotation == Parameter.empty:
return StateAnnotationType.NO_ANNOTATION

if annotation is None:
return StateAnnotationType.NONE

if (
get_origin(annotation) is Union or get_origin(annotation) is UnionType
) and NoneType in get_args(annotation):
return StateAnnotationType.NONE_OR_STATE

return StateAnnotationType.STATE


def extract_state_annotations_from_handler(
handler: Callable[..., Any],
) -> tuple[StateAnnotationType, Any]:
handler_sig = inspect.signature(handler)
state_annotation = get_param_annotation(
handler_sig.parameters.get("state")
)
state_annotation_type = get_state_annotation_type(state_annotation)

return state_annotation_type, state_annotation


class ProjectorWithNone[
State,
Identifier: EventSourceIdentifier,
Metadata = Mapping[str, Any],
](Projector[State | None, Identifier, Metadata], ABC):
def initial_state_factory(self) -> State | None:
return None

def _resolve_handler(
self, event: StoredEvent
) -> Callable[[State | None, StoredEvent], State | None]:
handler = super()._resolve_handler(event)
state_annotation_type, state_annotation = (
extract_state_annotations_from_handler(handler)
)

def _wrapped_handler(
state: State | None, event: StoredEvent
) -> State | None:
match state_annotation_type, state:
case StateAnnotationType.NONE_OR_STATE, _:
return handler(state, event)

case StateAnnotationType.NONE, None:
return handler(state, event)
case StateAnnotationType.NONE, _:
raise ValueError(
f"Initial handler {handler.__name__} should not be passed a state"
)

case StateAnnotationType.STATE, None:
raise ValueError(
f"Handler {handler.__name__} was passed None but expected {state_annotation}"
)
case StateAnnotationType.STATE, _:
return handler(state, event)

case StateAnnotationType.NO_ANNOTATION, _:
return handler(state, event)

return _wrapped_handler
156 changes: 156 additions & 0 deletions tests/unit/logicblocks/event/projection/test_projector_with_none.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
from inspect import Parameter
from typing import Any, Mapping, Optional, Union

import pytest

from logicblocks.event.projection.projector_with_none import (
ProjectorWithNone,
StateAnnotationType,
extract_state_annotations_from_handler,
)
from logicblocks.event.sources import InMemoryEventSource
from logicblocks.event.testing import StoredEventBuilder
from logicblocks.event.types import LogIdentifier, StoredEvent


class TestExtractStateAnnotationsFromHandler:
def test_no_arg(self):
def handler(event):
pass

assert extract_state_annotations_from_handler(handler) == (
StateAnnotationType.NO_ANNOTATION,
Parameter.empty,
)

def test_no_annotation(self):
def handler(state, event):
pass

assert extract_state_annotations_from_handler(handler) == (
StateAnnotationType.NO_ANNOTATION,
Parameter.empty,
)

def test_with_lambda(self):
assert extract_state_annotations_from_handler(
lambda state, event: ...
) == (
StateAnnotationType.NO_ANNOTATION,
Parameter.empty,
)

def test_none(self):
def handler(state: None, event):
pass

assert extract_state_annotations_from_handler(handler) == (
StateAnnotationType.NONE,
None,
)

def test_optional(self):
def handler(state: Optional[str], event):
pass

assert extract_state_annotations_from_handler(handler) == (
StateAnnotationType.NONE_OR_STATE,
Optional[str],
)

def test_union_none(self):
def handler(state: Union[str, None], event):
pass

assert extract_state_annotations_from_handler(handler) == (
StateAnnotationType.NONE_OR_STATE,
Union[str, None],
)

def test_union_none_shorthand(self):
def handler(state: str | None, event):
pass

assert extract_state_annotations_from_handler(handler) == (
StateAnnotationType.NONE_OR_STATE,
str | None,
)

def test_other_annotation(self):
def handler(state: str, event):
pass

assert extract_state_annotations_from_handler(handler) == (
StateAnnotationType.STATE,
str,
)


class TestProjectorWithNone:
class MyProjector(ProjectorWithNone[Mapping[str, str], LogIdentifier]):
def initial_metadata_factory(self) -> Mapping[str, Any]:
return {}

def id_factory(
self, state: Mapping[str, str] | None, source: LogIdentifier
) -> str:
return str(source)

@staticmethod
def test_first_event(
state: None, event: StoredEvent
) -> Mapping[str, str]:
return {}

@staticmethod
def test_later_event(
state: Mapping[str, str], event: StoredEvent
) -> Mapping[str, str]:
return state

@staticmethod
def test_any_position_event(
state: Mapping[str, str] | None, event: StoredEvent
) -> Mapping[str, str]:
return {}

async def test_events_work_in_correct_position(self):
events = [
StoredEventBuilder(name="test_first_event", position=0).build(),
StoredEventBuilder(name="test_later_event", position=1).build(),
]
source = InMemoryEventSource(events, LogIdentifier())

await self.MyProjector().project(source=source)

async def test_later_event_raises_exception_is_used_first(self):
events = [
StoredEventBuilder(name="test_later_event", position=0).build(),
]
source = InMemoryEventSource(events, LogIdentifier())

with pytest.raises(ValueError):
await self.MyProjector().project(source=source)

async def test_first_event_raises_exception_is_used_later(self):
events = [
StoredEventBuilder(name="test_later_event", position=0).build(),
StoredEventBuilder(name="test_first_event", position=1).build(),
]
source = InMemoryEventSource(events, LogIdentifier())

with pytest.raises(ValueError):
await self.MyProjector().project(source=source)

async def test_any_events_work_in_any_position(self):
events = [
StoredEventBuilder(
name="test_any_position_event", position=0
).build(),
StoredEventBuilder(
name="test_any_position_event", position=1
).build(),
]
source = InMemoryEventSource(events, LogIdentifier())

await self.MyProjector().project(source=source)