Skip to content

Commit 9f9abc5

Browse files
committed
Extracted BaseInfrastructureFactory, so that "normal" and DCB factory classes can live in the same Python module, and also be detected and used separately.
1 parent e30f801 commit 9f9abc5

File tree

6 files changed

+101
-87
lines changed

6 files changed

+101
-87
lines changed

eventsourcing/dcb/persistence.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Tagged,
1818
TMutates,
1919
)
20-
from eventsourcing.persistence import InfrastructureFactory, TTrackingRecorder
20+
from eventsourcing.persistence import BaseInfrastructureFactory, TTrackingRecorder
2121
from eventsourcing.utils import get_topic
2222

2323
if TYPE_CHECKING:
@@ -105,7 +105,7 @@ class NotFoundError(Exception):
105105
pass
106106

107107

108-
class DCBInfrastructureFactory(InfrastructureFactory[TTrackingRecorder], ABC):
108+
class DCBInfrastructureFactory(BaseInfrastructureFactory[TTrackingRecorder], ABC):
109109
@abstractmethod
110110
def dcb_event_store(self) -> DCBRecorder:
111111
pass # pragma: no cover

eventsourcing/dcb/postgres_tt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from eventsourcing.dcb.popo import SimpleDCBReadResponse
1818
from eventsourcing.persistence import IntegrityError, ProgrammingError
1919
from eventsourcing.postgres import (
20+
BasePostgresFactory,
2021
PostgresDatastore,
21-
PostgresFactory,
2222
PostgresRecorder,
2323
PostgresTrackingRecorder,
2424
)
@@ -601,7 +601,7 @@ class PsycopgDCBQueryItem(NamedTuple):
601601

602602

603603
class PostgresTTDCBFactory(
604-
PostgresFactory,
604+
BasePostgresFactory[PostgresTrackingRecorder],
605605
DCBInfrastructureFactory[PostgresTrackingRecorder],
606606
):
607607
def dcb_event_store(self) -> DCBRecorder:

eventsourcing/persistence.py

Lines changed: 57 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -663,18 +663,33 @@ class InfrastructureFactoryError(EventSourcingError):
663663
"""Raised when an infrastructure factory cannot be created."""
664664

665665

666-
class InfrastructureFactory(ABC, Generic[TTrackingRecorder]):
666+
class BaseInfrastructureFactory(ABC, Generic[TTrackingRecorder]):
667667
"""Abstract base class for infrastructure factories."""
668668

669669
PERSISTENCE_MODULE = "PERSISTENCE_MODULE"
670670
TRANSCODER_TOPIC = "TRANSCODER_TOPIC"
671-
MAPPER_TOPIC = "MAPPER_TOPIC"
672671
CIPHER_TOPIC = "CIPHER_TOPIC"
673672
COMPRESSOR_TOPIC = "COMPRESSOR_TOPIC"
674-
IS_SNAPSHOTTING_ENABLED = "IS_SNAPSHOTTING_ENABLED"
675-
APPLICATION_RECORDER_TOPIC = "APPLICATION_RECORDER_TOPIC"
676-
TRACKING_RECORDER_TOPIC = "TRACKING_RECORDER_TOPIC"
677-
PROCESS_RECORDER_TOPIC = "PROCESS_RECORDER_TOPIC"
673+
674+
def __init__(self, env: Environment | EnvType | None):
675+
"""Initialises infrastructure factory object with given application name."""
676+
self.env = env if isinstance(env, Environment) else Environment(env=env)
677+
self._is_entered = False
678+
679+
def __enter__(self) -> Self:
680+
self._is_entered = True
681+
return self
682+
683+
def __exit__(
684+
self,
685+
exc_type: type[BaseException] | None,
686+
exc_val: BaseException | None,
687+
exc_tb: TracebackType | None,
688+
) -> None:
689+
self._is_entered = False
690+
691+
def close(self) -> None:
692+
"""Closes any database connections, and anything else that needs closing."""
678693

679694
@classmethod
680695
def construct(
@@ -747,11 +762,6 @@ def construct(
747762
raise InfrastructureFactoryError(msg)
748763
return factory_cls(env=env)
749764

750-
def __init__(self, env: Environment | EnvType | None):
751-
"""Initialises infrastructure factory object with given application name."""
752-
self.env = env if isinstance(env, Environment) else Environment(env=env)
753-
self._is_entered = False
754-
755765
def transcoder(
756766
self,
757767
) -> Transcoder:
@@ -763,32 +773,6 @@ def transcoder(
763773
transcoder_class = JSONTranscoder
764774
return transcoder_class()
765775

766-
def mapper(
767-
self,
768-
transcoder: Transcoder | None = None,
769-
mapper_class: type[Mapper[TAggregateID]] | None = None,
770-
) -> Mapper[TAggregateID]:
771-
"""Constructs a mapper."""
772-
# Resolve MAPPER_TOPIC if no given class.
773-
if mapper_class is None:
774-
mapper_topic = self.env.get(self.MAPPER_TOPIC)
775-
mapper_class = (
776-
resolve_topic(mapper_topic) if mapper_topic else Mapper[TAggregateID]
777-
)
778-
779-
# Check we have a mapper class.
780-
assert mapper_class is not None
781-
origin_mapper_class = typing.get_origin(mapper_class) or mapper_class
782-
assert isinstance(origin_mapper_class, type), mapper_class
783-
assert issubclass(origin_mapper_class, Mapper), mapper_class
784-
785-
# Construct and return a mapper.
786-
return mapper_class(
787-
transcoder=transcoder or self.transcoder(),
788-
cipher=self.cipher(),
789-
compressor=self.compressor(),
790-
)
791-
792776
def cipher(self) -> Cipher | None:
793777
"""Reads environment variables 'CIPHER_TOPIC'
794778
and 'CIPHER_KEY' to decide whether or not
@@ -822,6 +806,42 @@ def compressor(self) -> Compressor | None:
822806
compressor = compressor_cls
823807
return compressor
824808

809+
810+
class InfrastructureFactory(BaseInfrastructureFactory[TTrackingRecorder]):
811+
"""Abstract base class for Application factories."""
812+
813+
MAPPER_TOPIC = "MAPPER_TOPIC"
814+
IS_SNAPSHOTTING_ENABLED = "IS_SNAPSHOTTING_ENABLED"
815+
APPLICATION_RECORDER_TOPIC = "APPLICATION_RECORDER_TOPIC"
816+
TRACKING_RECORDER_TOPIC = "TRACKING_RECORDER_TOPIC"
817+
PROCESS_RECORDER_TOPIC = "PROCESS_RECORDER_TOPIC"
818+
819+
def mapper(
820+
self,
821+
transcoder: Transcoder | None = None,
822+
mapper_class: type[Mapper[TAggregateID]] | None = None,
823+
) -> Mapper[TAggregateID]:
824+
"""Constructs a mapper."""
825+
# Resolve MAPPER_TOPIC if no given class.
826+
if mapper_class is None:
827+
mapper_topic = self.env.get(self.MAPPER_TOPIC)
828+
mapper_class = (
829+
resolve_topic(mapper_topic) if mapper_topic else Mapper[TAggregateID]
830+
)
831+
832+
# Check we have a mapper class.
833+
assert mapper_class is not None
834+
origin_mapper_class = typing.get_origin(mapper_class) or mapper_class
835+
assert isinstance(origin_mapper_class, type), mapper_class
836+
assert issubclass(origin_mapper_class, Mapper), mapper_class
837+
838+
# Construct and return a mapper.
839+
return mapper_class(
840+
transcoder=transcoder or self.transcoder(),
841+
cipher=self.cipher(),
842+
compressor=self.compressor(),
843+
)
844+
825845
def event_store(
826846
self,
827847
mapper: Mapper[TAggregateID] | None = None,
@@ -858,21 +878,6 @@ def is_snapshotting_enabled(self) -> bool:
858878
"""
859879
return strtobool(self.env.get(self.IS_SNAPSHOTTING_ENABLED, "no"))
860880

861-
def __enter__(self) -> Self:
862-
self._is_entered = True
863-
return self
864-
865-
def __exit__(
866-
self,
867-
exc_type: type[BaseException] | None,
868-
exc_val: BaseException | None,
869-
exc_tb: TracebackType | None,
870-
) -> None:
871-
self._is_entered = False
872-
873-
def close(self) -> None:
874-
"""Closes any database connections, and anything else that needs closing."""
875-
876881

877882
@dataclass(frozen=True)
878883
class Tracking:

eventsourcing/postgres.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from eventsourcing.persistence import (
2929
AggregateRecorder,
3030
ApplicationRecorder,
31+
BaseInfrastructureFactory,
3132
DatabaseError,
3233
DataError,
3334
InfrastructureFactory,
@@ -45,6 +46,7 @@
4546
Subscription,
4647
Tracking,
4748
TrackingRecorder,
49+
TTrackingRecorder,
4850
)
4951
from eventsourcing.utils import Environment, EnvType, resolve_topic, retry, strtobool
5052

@@ -1109,7 +1111,7 @@ def _insert_events(
11091111
super()._insert_events(curs, stored_events, **kwargs)
11101112

11111113

1112-
class PostgresFactory(InfrastructureFactory[PostgresTrackingRecorder]):
1114+
class BasePostgresFactory(BaseInfrastructureFactory[TTrackingRecorder]):
11131115
POSTGRES_DBNAME = "POSTGRES_DBNAME"
11141116
POSTGRES_HOST = "POSTGRES_HOST"
11151117
POSTGRES_PORT = "POSTGRES_PORT"
@@ -1132,11 +1134,6 @@ class PostgresFactory(InfrastructureFactory[PostgresTrackingRecorder]):
11321134
POSTGRES_ENABLE_DB_FUNCTIONS = "POSTGRES_ENABLE_DB_FUNCTIONS"
11331135
CREATE_TABLE = "CREATE_TABLE"
11341136

1135-
aggregate_recorder_class = PostgresAggregateRecorder
1136-
application_recorder_class = PostgresApplicationRecorder
1137-
tracking_recorder_class = PostgresTrackingRecorder
1138-
process_recorder_class = PostgresProcessRecorder
1139-
11401137
def __init__(self, env: Environment | EnvType | None):
11411138
super().__init__(env)
11421139
dbname = self.env.get(self.POSTGRES_DBNAME)
@@ -1334,6 +1331,32 @@ def __init__(self, env: Environment | EnvType | None):
13341331
def env_create_table(self) -> bool:
13351332
return strtobool(self.env.get(self.CREATE_TABLE) or "yes")
13361333

1334+
def __enter__(self) -> Self:
1335+
self.datastore.__enter__()
1336+
return self
1337+
1338+
def __exit__(
1339+
self,
1340+
exc_type: type[BaseException] | None,
1341+
exc_val: BaseException | None,
1342+
exc_tb: TracebackType | None,
1343+
) -> None:
1344+
self.datastore.__exit__(exc_type, exc_val, exc_tb)
1345+
1346+
def close(self) -> None:
1347+
with contextlib.suppress(AttributeError):
1348+
self.datastore.close()
1349+
1350+
1351+
class PostgresFactory(
1352+
BasePostgresFactory[PostgresTrackingRecorder],
1353+
InfrastructureFactory[PostgresTrackingRecorder],
1354+
):
1355+
aggregate_recorder_class = PostgresAggregateRecorder
1356+
application_recorder_class = PostgresApplicationRecorder
1357+
tracking_recorder_class = PostgresTrackingRecorder
1358+
process_recorder_class = PostgresProcessRecorder
1359+
13371360
def aggregate_recorder(self, purpose: str = "events") -> AggregateRecorder:
13381361
prefix = self.env.name.lower() or "stored"
13391362
events_table_name = prefix + "_" + purpose
@@ -1412,21 +1435,5 @@ def process_recorder(self) -> ProcessRecorder:
14121435
recorder.create_table()
14131436
return recorder
14141437

1415-
def __enter__(self) -> Self:
1416-
self.datastore.__enter__()
1417-
return self
1418-
1419-
def __exit__(
1420-
self,
1421-
exc_type: type[BaseException] | None,
1422-
exc_val: BaseException | None,
1423-
exc_tb: TracebackType | None,
1424-
) -> None:
1425-
self.datastore.__exit__(exc_type, exc_val, exc_tb)
1426-
1427-
def close(self) -> None:
1428-
with contextlib.suppress(AttributeError):
1429-
self.datastore.close()
1430-
14311438

14321439
Factory = PostgresFactory

eventsourcing/tests/persistence.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ def insert_events() -> None:
629629
rate = num_jobs * num_events_per_job / (ended - started).total_seconds()
630630
print(f"Rate: {rate:.0f} inserts per second")
631631

632-
def optional_test_insert_subscribe(self) -> None:
632+
def optional_test_insert_subscribe(self, initial_position: int = 0) -> None:
633633

634634
recorder = self.create_recorder()
635635

@@ -655,7 +655,9 @@ def optional_test_insert_subscribe(self) -> None:
655655

656656
notification_ids = recorder.insert_events([stored_event1, stored_event2])
657657
if self.EXPECT_CONTIGUOUS_NOTIFICATION_IDS:
658-
self.assertEqual(notification_ids, [1, 2])
658+
self.assertEqual(
659+
notification_ids, [1 + initial_position, 2 + initial_position]
660+
)
659661

660662
# Get the max notification ID.
661663
max_notification_id2 = recorder.max_notification_id()
@@ -698,8 +700,8 @@ def optional_test_insert_subscribe(self) -> None:
698700
stored_event2.originator_version, notifications[1].originator_version
699701
)
700702
if self.EXPECT_CONTIGUOUS_NOTIFICATION_IDS:
701-
self.assertEqual(1, notifications[0].id)
702-
self.assertEqual(2, notifications[1].id)
703+
self.assertEqual(1 + initial_position, notifications[0].id)
704+
self.assertEqual(2 + initial_position, notifications[1].id)
703705

704706
# Store a third event.
705707
stored_event3 = StoredEvent(
@@ -710,7 +712,7 @@ def optional_test_insert_subscribe(self) -> None:
710712
)
711713
notification_ids = recorder.insert_events([stored_event3])
712714
if self.EXPECT_CONTIGUOUS_NOTIFICATION_IDS:
713-
self.assertEqual(notification_ids, [3])
715+
self.assertEqual(notification_ids, [3 + initial_position])
714716

715717
# Receive events from the subscription.
716718
for notification in subscription:
@@ -726,7 +728,7 @@ def optional_test_insert_subscribe(self) -> None:
726728
stored_event3.originator_version, notifications[2].originator_version
727729
)
728730
if self.EXPECT_CONTIGUOUS_NOTIFICATION_IDS:
729-
self.assertEqual(3, notifications[2].id)
731+
self.assertEqual(3 + initial_position, notifications[2].id)
730732

731733
# Start a subscription with int value for 'start'.
732734
with recorder.subscribe(gt=max_notification_id2) as subscription:
@@ -744,7 +746,7 @@ def optional_test_insert_subscribe(self) -> None:
744746
)
745747

746748
# Start a subscription, call stop() during iteration.
747-
with recorder.subscribe(gt=None) as subscription:
749+
with recorder.subscribe(gt=initial_position) as subscription:
748750

749751
# Receive events from the subscription.
750752
for i, _ in enumerate(subscription):

examples/coursebookingdcb/postgres_ts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from eventsourcing.dcb.popo import SimpleDCBReadResponse
1818
from eventsourcing.persistence import IntegrityError, ProgrammingError
1919
from eventsourcing.postgres import (
20+
BasePostgresFactory,
2021
PostgresDatastore,
21-
PostgresFactory,
2222
PostgresRecorder,
2323
PostgresTrackingRecorder,
2424
)
@@ -447,7 +447,7 @@ class PgDCBEventRow(TypedDict):
447447

448448

449449
class PostgresTSDCBFactory(
450-
PostgresFactory,
450+
BasePostgresFactory[PostgresTrackingRecorder],
451451
DCBInfrastructureFactory[PostgresTrackingRecorder],
452452
):
453453
def dcb_event_store(self) -> DCBRecorder:

0 commit comments

Comments (0)