Skip to content

Commit 6cceb15

Browse files
authored
Merge pull request #164 from mosquito/feafure/dataclasses
Conection stuck fixes
2 parents 0af4f63 + 5411880 commit 6cceb15

File tree

9 files changed

+841
-594
lines changed

9 files changed

+841
-594
lines changed

aiormq/abc.py

Lines changed: 65 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import asyncio
2+
import dataclasses
3+
import io
4+
import logging
25
from abc import ABC, abstractmethod, abstractproperty
36
from types import TracebackType
47
from typing import (
5-
Any, Awaitable, Callable, Coroutine, Dict, Iterable, NamedTuple, Optional,
6-
Set, Tuple, Type, Union,
8+
Any, Awaitable, Callable, Coroutine, Dict, Iterable, Optional, Set, Tuple,
9+
Type, Union,
710
)
811

12+
import pamqp
913
from pamqp import commands as spec
1014
from pamqp.base import Frame
1115
from pamqp.body import ContentBody
@@ -22,24 +26,24 @@
2226

2327
# noinspection PyShadowingNames
2428
class TaskWrapper:
25-
__slots__ = "exception", "task"
29+
__slots__ = "_exception", "task"
2630

27-
exception: Union[BaseException, Type[BaseException]]
31+
_exception: Union[BaseException, Type[BaseException]]
2832
task: asyncio.Task
2933

3034
def __init__(self, task: asyncio.Task):
3135
self.task = task
32-
self.exception = asyncio.CancelledError
36+
self._exception = asyncio.CancelledError
3337

3438
def throw(self, exception: ExceptionType) -> None:
35-
self.exception = exception
39+
self._exception = exception
3640
self.task.cancel()
3741

3842
async def __inner(self) -> Any:
3943
try:
4044
return await self.task
4145
except asyncio.CancelledError as e:
42-
raise self.exception from e
46+
raise self._exception from e
4347

4448
def __await__(self, *args: Any, **kwargs: Any) -> Any:
4549
return self.__inner().__await__()
@@ -59,7 +63,8 @@ def __repr__(self) -> str:
5963
GetResultType = Union[Basic.GetEmpty, Basic.GetOk]
6064

6165

62-
class DeliveredMessage(NamedTuple):
66+
@dataclasses.dataclass(frozen=True)
67+
class DeliveredMessage:
6368
delivery: Union[spec.Basic.Deliver, spec.Basic.Return, GetResultType]
6469
header: ContentHeader
6570
body: bytes
@@ -137,7 +142,8 @@ def message_count(self) -> Optional[int]:
137142
]
138143

139144

140-
class SSLCerts(NamedTuple):
145+
@dataclasses.dataclass(frozen=True)
146+
class SSLCerts:
141147
cert: Optional[str]
142148
key: Optional[str]
143149
capath: Optional[str]
@@ -146,7 +152,8 @@ class SSLCerts(NamedTuple):
146152
verify: bool
147153

148154

149-
class FrameReceived(NamedTuple):
155+
@dataclasses.dataclass(frozen=True)
156+
class FrameReceived:
150157
channel: int
151158
frame: str
152159

@@ -182,11 +189,56 @@ class FrameReceived(NamedTuple):
182189
]
183190

184191

185-
class ChannelFrame(NamedTuple):
186-
channel_number: int
187-
frames: Iterable[Union[FrameType, Heartbeat, ContentBody]]
192+
@dataclasses.dataclass(frozen=True)
193+
class ChannelFrame:
194+
payload: bytes
195+
should_close: bool
188196
drain_future: Optional[asyncio.Future] = None
189197

198+
def drain(self) -> None:
199+
if not self.should_drain:
200+
return
201+
202+
if self.drain_future is not None:
203+
self.drain_future.set_result(None)
204+
205+
@property
206+
def should_drain(self) -> bool:
207+
return self.drain_future is not None and not self.drain_future.done()
208+
209+
@classmethod
210+
def marshall(
211+
cls, channel_number: int,
212+
frames: Iterable[Union[FrameType, Heartbeat, ContentBody]],
213+
drain_future: Optional[asyncio.Future] = None,
214+
) -> "ChannelFrame":
215+
should_close = False
216+
217+
with io.BytesIO() as fp:
218+
for frame in frames:
219+
if should_close:
220+
logger = logging.getLogger(
221+
"aiormq.connection"
222+
).getChild(
223+
"marshall"
224+
)
225+
226+
logger.warning(
227+
"It looks like you are going to send a frame %r after "
228+
"the connection is closed, it's pointless, "
229+
"the frame is dropped.", frame,
230+
)
231+
continue
232+
if isinstance(frame, spec.Connection.CloseOk):
233+
should_close = True
234+
fp.write(pamqp.frame.marshal(frame, channel_number))
235+
236+
return cls(
237+
payload=fp.getvalue(),
238+
drain_future=drain_future,
239+
should_close=should_close,
240+
)
241+
190242

191243
class AbstractFutureStore:
192244
futures: Set[Union[asyncio.Future, TaskType]]

aiormq/channel.py

Lines changed: 64 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
import io
33
import logging
44
from collections import OrderedDict
5-
from contextlib import suppress
65
from functools import partial
76
from io import BytesIO
87
from random import getrandbits
98
from types import MappingProxyType
10-
from typing import Any, Dict, List, Mapping, Optional, Set, Type, Union
9+
from typing import (
10+
Any, Awaitable, Callable, Dict, List, Mapping, Optional, Set, Tuple, Type,
11+
Union,
12+
)
1113
from uuid import UUID
1214

1315
import pamqp.frame
@@ -166,7 +168,7 @@ async def rpc(
166168
try:
167169
await countdown(
168170
self.write_queue.put(
169-
ChannelFrame(
171+
ChannelFrame.marshall(
170172
channel_number=self.number,
171173
frames=[frame],
172174
),
@@ -195,7 +197,7 @@ async def rpc(
195197

196198
self.__close_event.set()
197199
self.write_queue.put_nowait(
198-
ChannelFrame(
200+
ChannelFrame.marshall(
199201
channel_number=self.number,
200202
frames=[
201203
spec.Channel.Close(
@@ -268,7 +270,7 @@ async def __get_content_header(self) -> ContentHeader:
268270

269271
return frame
270272

271-
async def _on_deliver(self, frame: spec.Basic.Deliver) -> None:
273+
async def _on_deliver_frame(self, frame: spec.Basic.Deliver) -> None:
272274
header: ContentHeader = await self.__get_content_header()
273275
message = await self._read_content(frame, header)
274276

@@ -281,7 +283,7 @@ async def _on_deliver(self, frame: spec.Basic.Deliver) -> None:
281283
# noinspection PyAsyncCall
282284
self.create_task(consumer(message))
283285

284-
async def _on_get(
286+
async def _on_get_frame(
285287
self, frame: Union[spec.Basic.GetOk, spec.Basic.GetEmpty],
286288
) -> None:
287289
message = None
@@ -308,7 +310,7 @@ async def _on_get(
308310
getter.set_result((frame, message))
309311
return
310312

311-
async def _on_return(self, frame: spec.Basic.Return) -> None:
313+
async def _on_return_frame(self, frame: spec.Basic.Return) -> None:
312314
header: ContentHeader = await self.__get_content_header()
313315
message = await self._read_content(frame, header)
314316
message_id = message.header.properties.message_id
@@ -366,7 +368,7 @@ def _confirm_delivery(
366368
DeliveryError(None, frame),
367369
) # pragma: nocover
368370

369-
async def _on_confirm(self, frame: ConfirmationFrameType) -> None:
371+
async def _on_confirm_frame(self, frame: ConfirmationFrameType) -> None:
370372
if not self.publisher_confirms: # pragma: nocover
371373
return
372374

@@ -386,61 +388,59 @@ async def _on_confirm(self, frame: ConfirmationFrameType) -> None:
386388
else:
387389
self._confirm_delivery(frame.delivery_tag, frame)
388390

391+
async def _on_cancel_frame(
392+
self,
393+
frame: Union[spec.Basic.CancelOk, spec.Basic.Cancel],
394+
) -> None:
395+
if frame.consumer_tag is not None:
396+
self.consumers.pop(frame.consumer_tag, None)
397+
398+
async def _on_close_frame(self, frame: spec.Channel.Close) -> None:
399+
exc: BaseException = exception_by_code(frame)
400+
self.write_queue.put_nowait(
401+
ChannelFrame.marshall(
402+
channel_number=self.number,
403+
frames=[spec.Channel.CloseOk()],
404+
),
405+
)
406+
self.connection.channels.pop(self.number, None)
407+
raise exc
408+
409+
async def _on_close_ok_frame(self, _: spec.Channel.CloseOk) -> None:
410+
self.connection.channels.pop(self.number, None)
411+
raise ChannelClosed()
412+
389413
async def _reader(self) -> None:
390-
while True:
391-
try:
414+
hooks: Mapping[Any, Tuple[bool, Callable[[Any], Awaitable[None]]]]
415+
416+
hooks = {
417+
spec.Basic.Deliver: (False, self._on_deliver_frame),
418+
spec.Basic.GetOk: (True, self._on_get_frame),
419+
spec.Basic.GetEmpty: (True, self._on_get_frame),
420+
spec.Basic.Return: (False, self._on_return_frame),
421+
spec.Basic.Cancel: (False, self._on_cancel_frame),
422+
spec.Basic.CancelOk: (True, self._on_cancel_frame),
423+
spec.Channel.Close: (False, self._on_close_frame),
424+
spec.Channel.CloseOk: (False, self._on_close_ok_frame),
425+
spec.Basic.Ack: (False, self._on_confirm_frame),
426+
spec.Basic.Nack: (False, self._on_confirm_frame),
427+
}
428+
429+
try:
430+
while True:
392431
frame = await self._get_frame()
432+
should_add_to_rpc, hook = hooks.get(type(frame), (True, None))
393433

394-
if isinstance(frame, spec.Basic.Deliver):
395-
with suppress(Exception):
396-
await self._on_deliver(frame)
397-
continue
398-
elif isinstance(
399-
frame, (spec.Basic.GetOk, spec.Basic.GetEmpty),
400-
):
401-
with suppress(Exception):
402-
await self._on_get(frame)
403-
elif isinstance(frame, spec.Basic.Return):
404-
with suppress(Exception):
405-
await self._on_return(frame)
406-
continue
407-
elif isinstance(frame, spec.Basic.Cancel):
408-
if frame.consumer_tag is None:
409-
continue
410-
self.consumers.pop(frame.consumer_tag, None)
411-
continue
412-
elif isinstance(frame, spec.Basic.CancelOk):
413-
if frame.consumer_tag is not None:
414-
self.consumers.pop(frame.consumer_tag, None)
415-
elif isinstance(frame, (spec.Basic.Ack, spec.Basic.Nack)):
416-
with suppress(Exception):
417-
await self._on_confirm(frame)
418-
continue
419-
elif isinstance(frame, spec.Channel.Close):
420-
exc: BaseException = exception_by_code(frame)
421-
self.write_queue.put_nowait(
422-
ChannelFrame(
423-
channel_number=self.number,
424-
frames=[spec.Channel.CloseOk()],
425-
),
426-
)
427-
428-
self.connection.channels.pop(self.number, None)
429-
await self._cancel_tasks(exc)
430-
return
431-
elif isinstance(frame, spec.Channel.CloseOk):
432-
if self.__close_event.is_set():
433-
await self._cancel_tasks(asyncio.TimeoutError())
434-
self.connection.channels.pop(self.number, None)
434+
if hook is not None:
435+
try:
436+
await hook(frame)
437+
except asyncio.CancelledError as e:
438+
await self._cancel_tasks(e)
435439
return
436-
437-
await self.rpc_frames.put(frame)
438-
except asyncio.CancelledError:
439-
return
440-
except Exception as e: # pragma: nocover
441-
log.debug("Channel reader exception %r", exc_info=e)
442-
await self._cancel_tasks(e)
443-
raise
440+
if should_add_to_rpc:
441+
await self.rpc_frames.put(frame)
442+
except Exception as e:
443+
await self._cancel_tasks(e)
444444

445445
@task
446446
async def _on_close(self, exc: Optional[ExceptionType] = None) -> None:
@@ -454,6 +454,7 @@ async def _on_close(self, exc: Optional[ExceptionType] = None) -> None:
454454
timeout=self.connection.connection_tune.heartbeat or None,
455455
)
456456
self.connection.channels.pop(self.number, None)
457+
self.__close_event.set()
457458

458459
async def basic_get(
459460
self, queue: str = "", no_ack: bool = False,
@@ -525,7 +526,7 @@ async def basic_ack(
525526
drain_future = self.create_future() if wait else None
526527

527528
await self.write_queue.put(
528-
ChannelFrame(
529+
ChannelFrame.marshall(
529530
frames=[
530531
spec.Basic.Ack(
531532
delivery_tag=delivery_tag,
@@ -553,7 +554,7 @@ async def basic_nack(
553554
drain_future = self.create_future() if wait else None
554555

555556
await self.write_queue.put(
556-
ChannelFrame(
557+
ChannelFrame.marshall(
557558
frames=[
558559
spec.Basic.Nack(
559560
delivery_tag=delivery_tag,
@@ -574,7 +575,7 @@ async def basic_reject(
574575
) -> None:
575576
drain_future = self.create_future()
576577
await self.write_queue.put(
577-
ChannelFrame(
578+
ChannelFrame.marshall(
578579
channel_number=self.number,
579580
frames=[
580581
spec.Basic.Reject(
@@ -664,7 +665,7 @@ async def basic_publish(
664665

665666
await countdown(
666667
self.write_queue.put(
667-
ChannelFrame(
668+
ChannelFrame.marshall(
668669
frames=body_frames,
669670
channel_number=self.number,
670671
drain_future=drain_future,

0 commit comments

Comments
 (0)