Skip to content

Commit 1315e38

Browse files
authored
Merge branch 'fork' into deab-savable-inh
2 parents 845a7d6 + b485d7d commit 1315e38

File tree

2 files changed

+65
-10
lines changed

2 files changed

+65
-10
lines changed

src/plumpy/coordinator.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# -*- coding: utf-8 -*-
22
from __future__ import annotations
33

4-
from typing import TYPE_CHECKING, Any, Callable, Hashable, Pattern, Protocol
4+
import re
5+
from typing import TYPE_CHECKING, Any, Callable, Hashable, Protocol
56

67
if TYPE_CHECKING:
78
# identifiers for subscribers
@@ -23,8 +24,8 @@ def add_rpc_subscriber(self, subscriber: 'RpcSubscriber', identifier: 'ID_TYPE |
2324
def add_broadcast_subscriber(
2425
self,
2526
subscriber: 'BroadcastSubscriber',
26-
subject_filters: list[Hashable | Pattern[str]] | None = None,
27-
sender_filters: list[Hashable | Pattern[str]] | None = None,
27+
subject_filters: list[Hashable | re.Pattern[str]] | None = None,
28+
sender_filters: list[Hashable | re.Pattern[str]] | None = None,
2829
identifier: 'ID_TYPE | None' = None,
2930
) -> Any: ...
3031

@@ -50,3 +51,58 @@ def broadcast_send(
5051
def task_send(self, task: Any, no_reply: bool = False) -> Any: ...
5152

5253
def close(self) -> None: ...
54+
55+
56+
class BroadcastFilter:
57+
"""A filter that can be used to limit the subjects and/or senders that will be received"""
58+
59+
def __init__(self, subscriber, subject=None, sender=None): # type: ignore
60+
self._subscriber = subscriber
61+
self._subject_filters = []
62+
self._sender_filters = []
63+
if subject is not None:
64+
self.add_subject_filter(subject)
65+
if sender is not None:
66+
self.add_sender_filter(sender)
67+
68+
@property
69+
def __name__(self): # type: ignore
70+
return 'BroadcastFilter'
71+
72+
def __call__(self, communicator, body, sender=None, subject=None, correlation_id=None): # type: ignore
73+
if self.is_filtered(sender, subject):
74+
return None
75+
return self._subscriber(communicator, body, sender, subject, correlation_id)
76+
77+
def is_filtered(self, sender, subject) -> bool: # type: ignore
78+
if subject is not None and self._subject_filters and not any(check(subject) for check in self._subject_filters):
79+
return True
80+
81+
if sender is not None and self._sender_filters and not any(check(sender) for check in self._sender_filters):
82+
return True
83+
84+
return False
85+
86+
def add_subject_filter(self, subject_filter: re.Pattern[str] | None) -> None:
87+
self._subject_filters.append(self._ensure_filter(subject_filter)) # type: ignore
88+
89+
def add_sender_filter(self, sender_filter: re.Pattern[str]) -> None:
90+
self._sender_filters.append(self._ensure_filter(sender_filter)) # type: ignore
91+
92+
@classmethod
93+
def _ensure_filter(cls, filter_value): # type: ignore
94+
if isinstance(filter_value, str):
95+
return re.compile(filter_value.replace('.', '[.]').replace('*', '.*')).match
96+
if isinstance(filter_value, re.Pattern): # pylint: disable=isinstance-second-argument-not-valid-type
97+
return filter_value.match
98+
99+
return lambda val: val == filter_value
100+
101+
@classmethod
102+
def _make_regex(cls, filter_str): # type: ignore
103+
"""
104+
:param filter_str: The filter string
105+
:type filter_str: str
106+
:return: The regular expression object
107+
"""
108+
return re.compile(filter_str.replace('.', '[.]'))

src/plumpy/processes.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,9 @@
3535
cast,
3636
)
3737

38-
import kiwipy
3938

40-
from plumpy.coordinator import Coordinator
4139
from plumpy.persistence import ensure_object_loader
40+
from plumpy.coordinator import BroadcastFilter, Coordinator
4241

4342
try:
4443
from aiocontextvars import ContextVar
@@ -390,12 +389,12 @@ def init(self) -> None:
390389

391390
try:
392391
# filter out state change broadcasts
393-
# XXX: remove dep on kiwipy
394-
subscriber = kiwipy.BroadcastFilter(self.broadcast_receive, subject=re.compile(r'^(?!state_changed).*'))
392+
subscriber = BroadcastFilter( # type: ignore
393+
self.broadcast_receive,
394+
subject=re.compile(r'^(?!state_changed).*'),
395+
)
395396
identifier = self._coordinator.add_broadcast_subscriber(subscriber, identifier=str(self.pid))
396-
# identifier = self._coordinator.add_broadcast_subscriber(
397-
# subscriber, subject_filters=[re.compile(r'^(?!state_changed).*')], identifier=str(self.pid)
398-
# )
397+
399398
self.add_cleanup(functools.partial(self._coordinator.remove_broadcast_subscriber, identifier))
400399
except concurrent.futures.TimeoutError:
401400
self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid)

0 commit comments

Comments
 (0)