Skip to content

Commit 986d5ea

Browse files
committed
Merge branch 'main' into tracer-middleware
2 parents daf5746 + 2486aae commit 986d5ea

File tree

8 files changed

+27
-17
lines changed

8 files changed

+27
-17
lines changed

src/workflows/contrib/start_service.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from __future__ import annotations
22

33
import sys
4+
from collections.abc import Callable
45
from optparse import SUPPRESS_HELP, OptionParser
56

67
import workflows
78
import workflows.frontend
89
import workflows.services
910
import workflows.transport
11+
from workflows.transport.common_transport import CommonTransport
1012

1113

1214
class ServiceStarter:
@@ -27,12 +29,14 @@ def on_parsing(options, args):
2729
"""
2830

2931
@staticmethod
30-
def on_transport_factory_preparation(transport_factory):
32+
def on_transport_factory_preparation(
33+
transport_factory,
34+
) -> Callable[[], CommonTransport] | None:
3135
"""Plugin hook to intercept/manipulate newly created Transport factories
3236
before first invocation."""
3337

3438
@staticmethod
35-
def on_transport_preparation(transport):
39+
def on_transport_preparation(transport: CommonTransport) -> CommonTransport | None:
3640
"""Plugin hook to intercept/manipulate newly created Transport objects
3741
before connecting."""
3842

@@ -136,7 +140,9 @@ def run(
136140
parser.error(f"Please specify a service name. {known_services_help}")
137141

138142
# Create Transport factory
139-
transport_factory = workflows.transport.lookup(options.transport)
143+
transport_factory: Callable[[], CommonTransport] = workflows.transport.lookup(
144+
options.transport
145+
)
140146

141147
# Call on_transport_factory_preparation hook
142148
transport_factory = (
@@ -147,7 +153,7 @@ def run(
147153
# Set up on_transport_preparation hook to affect newly created transport objects
148154
true_transport_factory = transport_factory
149155

150-
def on_transport_preparation_hook():
156+
def on_transport_preparation_hook() -> CommonTransport:
151157
transport_object = true_transport_factory()
152158
return self.on_transport_preparation(transport_object) or transport_object
153159

src/workflows/frontend/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
import multiprocessing
55
import threading
66
import time
7+
from collections.abc import Callable
78

89
import workflows
910
import workflows.frontend.utilization
1011
import workflows.services
1112
import workflows.transport
1213
import workflows.util
1314
from workflows.services.common_service import CommonService
15+
from workflows.transport.common_transport import CommonTransport
1416

1517
basestring = (str, bytes)
1618

@@ -22,9 +24,11 @@ class Frontend:
2224
service.
2325
"""
2426

27+
_transport_factory: Callable[[], CommonTransport]
28+
2529
def __init__(
2630
self,
27-
transport=None,
31+
transport: Callable[[], CommonTransport] | str | None = None,
2832
service=None,
2933
transport_command_channel=None,
3034
restart_service=False,
@@ -98,7 +102,6 @@ def __getitem__(self, key):
98102
self.log = logging.LoggerAdapter(
99103
logging.getLogger("workflows.frontend"), LogAdapter()
100104
)
101-
self.log.warn = self.log.warning # Add support for deprecated .warn
102105

103106
# Connect to the network transport layer
104107
if transport is None or isinstance(transport, basestring):

src/workflows/transport/__init__.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111
default_transport = "PikaTransport"
1212

1313

14-
def lookup(transport: str) -> type[CommonTransport]:
14+
def lookup(transport: str | None) -> type[CommonTransport]:
1515
"""Get a transport layer class based on its name."""
16-
return get_known_transports().get(
17-
transport, get_known_transports()[default_transport]
18-
)
16+
known_transports = get_known_transports()
17+
if transport not in known_transports:
18+
return known_transports[default_transport]
19+
return known_transports[transport]
1920

2021

2122
def add_command_line_options(

src/workflows/transport/common_transport.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class CommonTransport:
3333
#
3434

3535
def __init__(
36-
self, middleware: list[type[middleware.BaseTransportMiddleware]] = None
36+
self, middleware: list[type[middleware.BaseTransportMiddleware]] | None = None
3737
):
3838
if middleware is None:
3939
self.middleware = []

src/workflows/transport/middleware/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,12 @@ def transaction_begin(
9595
return call_next(subscription_id=subscription_id, **kwargs)
9696

9797
def transaction_abort(
98-
self, call_next: Callable, transaction_id: int = None, **kwargs
98+
self, call_next: Callable, transaction_id: int | None = None, **kwargs
9999
):
100100
call_next(transaction_id, **kwargs)
101101

102102
def transaction_commit(
103-
self, call_next: Callable, transaction_id: int = None, **kwargs
103+
self, call_next: Callable, transaction_id: int | None = None, **kwargs
104104
):
105105
call_next(transaction_id, **kwargs)
106106

@@ -172,7 +172,7 @@ def transaction_commit(self, call_next: Callable, *args, **kwargs):
172172

173173

174174
class TimerMiddleware(BaseTransportMiddleware):
175-
def __init__(self, logger: logging.Logger = None, level=logging.INFO):
175+
def __init__(self, logger: logging.Logger | None = None, level=logging.INFO):
176176
if logger is None:
177177
logger = logging.getLogger(__name__)
178178
self.logger = logger

src/workflows/transport/offline_transport.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class OfflineTransport(CommonTransport):
2828
config: dict[Any, Any] = {}
2929

3030
def __init__(
31-
self, middleware: list[type[middleware.BaseTransportMiddleware]] = None
31+
self, middleware: list[type[middleware.BaseTransportMiddleware]] | None = None
3232
):
3333
self._connected = False
3434
super().__init__(middleware=middleware)

src/workflows/transport/pika_transport.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class PikaTransport(CommonTransport):
6363
config: dict[Any, Any] = {}
6464

6565
def __init__(
66-
self, middleware: list[type[middleware.BaseTransportMiddleware]] = None
66+
self, middleware: list[type[middleware.BaseTransportMiddleware]] | None = None
6767
):
6868
self._channel = None
6969
self._conn = None

src/workflows/transport/stomp_transport.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class StompTransport(CommonTransport):
3434
config: dict[Any, Any] = {}
3535

3636
def __init__(
37-
self, middleware: list[type[middleware.BaseTransportMiddleware]] = None
37+
self, middleware: list[type[middleware.BaseTransportMiddleware]] | None = None
3838
):
3939
self._connected = False
4040
self._namespace = ""

0 commit comments

Comments
 (0)