Skip to content

Commit 51b4553

Browse files
committed
check for return type from task handlers
1 parent c71b5af commit 51b4553

File tree

4 files changed

+60
-49
lines changed

4 files changed

+60
-49
lines changed

src/modules/bidding/application/event.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
from modules.bidding.domain.repositories import ListingRepository
66
from modules.bidding.domain.value_objects import Seller
77
from modules.catalog.domain.events import ListingPublishedEvent
8+
from seedwork.application import EventResult
89

910

1011
@bidding_module.domain_event_handler
1112
def when_listing_is_published_start_auction(
1213
event: ListingPublishedEvent, listing_repository: ListingRepository
13-
):
14+
) -> EventResult:
1415
listing = Listing(
1516
id=event.listing_id,
1617
seller=Seller(id=event.seller_id),
@@ -19,3 +20,4 @@ def when_listing_is_published_start_auction(
1920
ends_at=datetime.now() + timedelta(days=7),
2021
)
2122
listing_repository.add(listing)
23+
return EventResult.success()
Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,24 @@
11
import pytest
2-
from sqlalchemy.orm import Session
32

3+
from modules.bidding.domain.repositories import (
4+
ListingRepository as BiddingListingRepository,
5+
)
46
from modules.catalog.domain.events import ListingPublishedEvent
57
from seedwork.domain.value_objects import UUID, Money
68

79

8-
@pytest.mark.skip # this test needs to be fixed
910
@pytest.mark.integration
10-
def test_create_listing_on_draft_published_event(engine):
11-
module = BiddingModule()
12-
listing_id = UUID.v4()
13-
with Session(engine) as db_session:
14-
with module.unit_of_work(db_session=db_session) as uow:
15-
module.handle_domain_event(
16-
ListingPublishedEvent(
17-
listing_id=listing_id,
18-
seller_id=UUID.v4(),
19-
ask_price=Money(10),
20-
)
11+
def test_create_listing_on_draft_published_event(app, engine):
12+
listing_id = UUID(int=1)
13+
with app.transaction_context() as ctx:
14+
ctx.handle_domain_event(
15+
ListingPublishedEvent(
16+
listing_id=listing_id,
17+
seller_id=UUID.v4(),
18+
ask_price=Money(10),
2119
)
22-
db_session.commit()
20+
)
2321

24-
with Session(engine) as db_session:
25-
with module.unit_of_work(db_session=db_session) as uow:
26-
assert uow.listing_repository.count() == 1
22+
with app.transaction_context() as ctx:
23+
listing_repository = ctx.get_service(BiddingListingRepository)
24+
assert listing_repository.count() == 1

src/seedwork/application/__init__.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def _wrap_with_middlewares(
117117
p = partial(middleware, self, p, command, query, event)
118118
return p
119119

120-
def execute_query(self, query):
120+
def execute_query(self, query) -> QueryResult:
121121
assert (
122122
self.task is None
123123
), "Cannot execute query while another task is being executed"
@@ -127,12 +127,15 @@ def execute_query(self, query):
127127
handler_kwargs = self.dependency_provider.get_handler_kwargs(
128128
handler_func, **self.overrides
129129
)
130-
handler_func = partial(handler_func, query, **handler_kwargs)
131-
wrapped_handler = self._wrap_with_middlewares(handler_func, query=query)
130+
p = partial(handler_func, query, **handler_kwargs)
131+
wrapped_handler = self._wrap_with_middlewares(p, query=query)
132132
result = wrapped_handler()
133+
assert isinstance(
134+
result, QueryResult
135+
), f"Got {result} instead of QueryResult from {handler_func}"
133136
return result
134137

135-
def execute_command(self, command):
138+
def execute_command(self, command) -> CommandResult:
136139
assert (
137140
self.task is None
138141
), "Cannot execute command while another task is being executed"
@@ -142,53 +145,55 @@ def execute_command(self, command):
142145
handler_kwargs = self.dependency_provider.get_handler_kwargs(
143146
handler_func, **self.overrides
144147
)
145-
handler_func = partial(handler_func, command, **handler_kwargs)
146-
wrapped_handler = self._wrap_with_middlewares(handler_func, command=command)
148+
p = partial(handler_func, command, **handler_kwargs)
149+
wrapped_handler = self._wrap_with_middlewares(p, command=command)
147150

148151
# execute wrapped command handler
149152
command_result = wrapped_handler()
153+
assert isinstance(
154+
command_result, CommandResult
155+
), f"Got {command_result} instead of CommandResult from {handler_func}"
150156

151157
self.next_commands = []
152158
self.integration_events = []
153159
event_queue = command_result.events.copy()
154160
while len(event_queue) > 0:
155161
event = event_queue.pop(0)
156162
if isinstance(event, IntegrationEvent):
157-
self._process_integration_event(event)
163+
self.collect_integration_event(event)
158164

159165
elif isinstance(event, DomainEvent):
160-
new_command, new_events = self._process_domain_event(event)
161-
self.next_commands.extend(new_command)
162-
event_queue.extend(new_events)
166+
event_results = self.handle_domain_event(event)
167+
self.next_commands.extend(event_results.commands)
168+
event_queue.extend(event_results.events)
163169

164170
return CommandResult.success(payload=command_result.payload)
165171

172+
def handle_domain_event(self, event) -> EventResultSet:
173+
event_results = []
174+
for handler_func in self.app.get_event_handlers(event):
175+
handler_kwargs = self.dependency_provider.get_handler_kwargs(
176+
handler_func, **self.overrides
177+
)
178+
p = partial(handler_func, event, **handler_kwargs)
179+
wrapped_handler = self._wrap_with_middlewares(p, event=event)
180+
result = wrapped_handler()
181+
assert isinstance(
182+
result, EventResult
183+
), f"Got {result} instead of EventResult from {handler_func}"
184+
event_results.append(result)
185+
return EventResultSet(event_results)
186+
187+
def collect_integration_event(self, event):
188+
self.integration_events.append(event)
189+
166190
def get_service(self, service_cls):
167191
return self.dependency_provider.get_dependency(service_cls)
168192

169193
@property
170194
def current_user(self):
171195
return self.dependency_provider.get_dependency("current_user")
172196

173-
def _process_integration_event(self, event):
174-
self.integration_events.append(event)
175-
176-
def _process_domain_event(self, event):
177-
new_commands = []
178-
new_events = []
179-
for handler_func in self.app.get_event_handlers(event):
180-
handler_kwargs = self.dependency_provider.get_handler_kwargs(
181-
handler_func, **self.overrides
182-
)
183-
event_handler = partial(handler_func, event, **handler_kwargs)
184-
wrapped_handler = self._wrap_with_middlewares(event_handler, event=event)
185-
result = wrapped_handler()
186-
if isinstance(result, Command):
187-
new_commands.append(result)
188-
elif isinstance(result, EventResult):
189-
new_events.extend(result.events)
190-
return new_commands, new_events
191-
192197

193198
class ApplicationModule:
194199
def __init__(self, name, version=1.0):

src/seedwork/application/events.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class EventResult:
2424

2525
event_id: UUID = None
2626
payload: Any = None
27+
command: Any = None # command th
2728
events: list[DomainEvent] = field(default_factory=list)
2829
errors: list[Any] = field(default_factory=list)
2930

@@ -48,14 +49,14 @@ def failure(cls, message="Failure", exception=None) -> "CommandResult":
4849

4950
@classmethod
5051
def success(
51-
cls, event_id=None, payload=None, event=None, events=None
52+
cls, event_id=None, payload=None, command=None, event=None, events=None
5253
) -> "EventResult":
5354
"""Creates a successful result"""
5455
if events is None:
5556
events = []
5657
if event:
5758
events.append(event)
58-
return cls(event_id=event_id, payload=payload, events=events)
59+
return cls(event_id=event_id, payload=payload, command=command, events=events)
5960

6061

6162
class EventResultSet(set):
@@ -70,3 +71,8 @@ def events(self):
7071
for event in self:
7172
all_events.extend(event.events)
7273
return all_events
74+
75+
@property
76+
def commands(self):
77+
all_commands = [event.command for event in self if event.command]
78+
return all_commands

0 commit comments

Comments
 (0)