Skip to content

Commit 3fd0b08

Browse files
committed
Refactoring client and worker to pass around interceptors
Signed-off-by: Ryan Lettieri <[email protected]>
1 parent efc0146 commit 3fd0b08

File tree

5 files changed

+86
-35
lines changed

5 files changed

+86
-35
lines changed

durabletask/client.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import durabletask.internal.orchestrator_service_pb2 as pb
1616
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
1717
import durabletask.internal.shared as shared
18+
from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
19+
1820
from durabletask import task
1921

2022
TInput = TypeVar('TInput')
@@ -96,8 +98,23 @@ def __init__(self, *,
9698
metadata: Optional[list[tuple[str, str]]] = None,
9799
log_handler: Optional[logging.Handler] = None,
98100
log_formatter: Optional[logging.Formatter] = None,
99-
secure_channel: bool = False):
100-
channel = shared.get_grpc_channel(host_address, metadata, secure_channel=secure_channel)
101+
secure_channel: bool = False,
102+
interceptors: Optional[list] = None):
103+
104+
# Determine the interceptors to use
105+
if interceptors is not None:
106+
self._interceptors = interceptors
107+
elif metadata:
108+
self._interceptors = [DefaultClientInterceptorImpl(metadata)]
109+
else:
110+
self._interceptors = None
111+
112+
channel = shared.get_grpc_channel(
113+
host_address=host_address,
114+
metadata=metadata,
115+
secure_channel=secure_channel,
116+
interceptors=self._interceptors
117+
)
101118
self._stub = stubs.TaskHubSidecarServiceStub(channel)
102119
self._logger = shared.get_logger("client", log_handler, log_formatter)
103120

durabletask/internal/shared.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,6 @@
99

1010
import grpc
1111

12-
from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
13-
from externalpackages.durabletaskscheduler.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl
14-
1512
# Field name used to indicate that an object was automatically serialized
1613
# and should be deserialized as a SimpleNamespace
1714
AUTO_SERIALIZED = "__durabletask_autoobject__"
@@ -26,8 +23,10 @@ def get_default_host_address() -> str:
2623

2724
def get_grpc_channel(
2825
host_address: Optional[str],
29-
metadata: Optional[list[tuple[str, str]]],
30-
secure_channel: bool = False) -> grpc.Channel:
26+
metadata: Optional[list[tuple[str, str]]] = None,
27+
secure_channel: bool = False,
28+
interceptors: Optional[list] = None) -> grpc.Channel:
29+
3130
if host_address is None:
3231
host_address = get_default_host_address()
3332

@@ -45,19 +44,14 @@ def get_grpc_channel(
4544
host_address = host_address[len(protocol):]
4645
break
4746

47+
# Create the base channel
4848
if secure_channel:
4949
channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials())
5050
else:
5151
channel = grpc.insecure_channel(host_address)
5252

53-
if metadata is not None and len(metadata) > 0:
54-
for key, _ in metadata:
55-
# Check if we are using DTS as the backend and if so, construct the DTS specific interceptors
56-
if key == "dts":
57-
interceptors = [DTSDefaultClientInterceptorImpl(metadata)]
58-
break
59-
else:
60-
interceptors = [DefaultClientInterceptorImpl(metadata)]
53+
# Apply interceptors ONLY if they exist
54+
if interceptors:
6155
channel = grpc.intercept_channel(channel, *interceptors)
6256
return channel
6357

durabletask/worker.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
import durabletask.internal.orchestrator_service_pb2 as pb
1717
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
1818
import durabletask.internal.shared as shared
19+
1920
from durabletask import task
21+
from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
2022

2123
TInput = TypeVar('TInput')
2224
TOutput = TypeVar('TOutput')
@@ -88,15 +90,25 @@ def __init__(self, *,
8890
metadata: Optional[list[tuple[str, str]]] = None,
8991
log_handler=None,
9092
log_formatter: Optional[logging.Formatter] = None,
91-
secure_channel: bool = False):
93+
secure_channel: bool = False,
94+
interceptors: Optional[list[grpc.ServerInterceptor]] = None): # Add interceptors
9295
self._registry = _Registry()
9396
self._host_address = host_address if host_address else shared.get_default_host_address()
94-
self._metadata = metadata
97+
self._metadata = metadata or [] # Ensure metadata is never None
9598
self._logger = shared.get_logger("worker", log_handler, log_formatter)
9699
self._shutdown = Event()
97100
self._is_running = False
98101
self._secure_channel = secure_channel
99102

103+
# Determine the interceptors to use
104+
if interceptors is not None:
105+
self._interceptors = interceptors
106+
elif self._metadata:
107+
self._interceptors = [DefaultClientInterceptorImpl(self._metadata)]
108+
else:
109+
self._interceptors = None
110+
111+
100112
def __enter__(self):
101113
return self
102114

@@ -117,7 +129,12 @@ def add_activity(self, fn: task.Activity) -> str:
117129

118130
def start(self):
119131
"""Starts the worker on a background thread and begins listening for work items."""
120-
channel = shared.get_grpc_channel(self._host_address, self._metadata, self._secure_channel)
132+
133+
if self._metadata:
134+
interceptors = [DefaultClientInterceptorImpl(self._metadata)]
135+
else:
136+
interceptors = None
137+
channel = shared.get_grpc_channel(self._host_address, self._metadata, self._secure_channel, interceptors)
121138
stub = stubs.TaskHubSidecarServiceStub(channel)
122139

123140
if self._is_running:

externalpackages/durabletaskscheduler/durabletask_scheduler_client.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,40 @@
44
from typing import Optional
55
from durabletask.client import TaskHubGrpcClient
66
from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager
7+
from externalpackages.durabletaskscheduler.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl
78

9+
# Client class used for Durable Task Scheduler (DTS)
810
class DurableTaskSchedulerClient(TaskHubGrpcClient):
911
def __init__(self,
1012
host_address: str,
1113
secure_channel: bool,
12-
metadata: Optional[list[tuple[str, str]]] = None,
14+
metadata: Optional[list[tuple[str, str]]] = [],
1315
use_managed_identity: Optional[bool] = False,
1416
client_id: Optional[str] = None,
1517
taskhub: str = None,
1618
**kwargs):
17-
if metadata is None:
18-
metadata = [] # Ensure metadata is initialized
19-
self._metadata = metadata
20-
self._use_managed_identity = use_managed_identity
21-
self._client_id = client_id
22-
self._metadata.append(("taskhub", taskhub))
19+
20+
# Ensure metadata is a list
21+
metadata = metadata or []
22+
self._metadata = metadata.copy() # Use a copy to avoid modifying original
23+
24+
# Append DurableTask-specific metadata
25+
self._metadata.append(("taskhub", taskhub or "default-taskhub"))
2326
self._metadata.append(("dts", "True"))
2427
self._metadata.append(("use_managed_identity", str(use_managed_identity)))
25-
self._metadata.append(("client_id", str(client_id)))
28+
self._metadata.append(("client_id", str(client_id or "None")))
29+
2630
self._access_token_manager = AccessTokenManager(metadata=self._metadata)
2731
self.__update_metadata_with_token()
28-
super().__init__(host_address=host_address, secure_channel=secure_channel, metadata=self._metadata, **kwargs)
32+
interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)]
33+
34+
super().__init__(
35+
host_address=host_address,
36+
secure_channel=secure_channel,
37+
metadata=self._metadata,
38+
interceptors=interceptors, # Now explicitly passing interceptors
39+
**kwargs
40+
)
2941

3042
def __update_metadata_with_token(self):
3143
"""

externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Optional
55
from durabletask.worker import TaskHubGrpcWorker
66
from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager
7+
from externalpackages.durabletaskscheduler.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl
78

89
# Worker class used for Durable Task Scheduler (DTS)
910
class DurableTaskSchedulerWorker(TaskHubGrpcWorker):
@@ -15,18 +16,28 @@ def __init__(self,
1516
client_id: Optional[str] = None,
1617
taskhub: str = None,
1718
**kwargs):
18-
if metadata is None:
19-
metadata = [] # Ensure metadata is initialized
20-
self._metadata = metadata
21-
self._use_managed_identity = use_managed_identity
22-
self._client_id = client_id
23-
self._metadata.append(("taskhub", taskhub))
19+
20+
# Ensure metadata is a list
21+
metadata = metadata or []
22+
self._metadata = metadata.copy() # Copy to prevent modifying input
23+
24+
# Append DurableTask-specific metadata
25+
self._metadata.append(("taskhub", taskhub or "default-taskhub"))
2426
self._metadata.append(("dts", "True"))
2527
self._metadata.append(("use_managed_identity", str(use_managed_identity)))
26-
self._metadata.append(("client_id", str(client_id)))
28+
self._metadata.append(("client_id", str(client_id or "None")))
29+
2730
self._access_token_manager = AccessTokenManager(metadata=self._metadata)
2831
self.__update_metadata_with_token()
29-
super().__init__(host_address=host_address, secure_channel=secure_channel, metadata=self._metadata, **kwargs)
32+
interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)]
33+
34+
super().__init__(
35+
host_address=host_address,
36+
secure_channel=secure_channel,
37+
metadata=self._metadata,
38+
interceptors=interceptors,
39+
**kwargs
40+
)
3041

3142
def __update_metadata_with_token(self):
3243
"""

0 commit comments

Comments
 (0)