Skip to content

Commit 441bafb

Browse files
committed
Add timeout functionality to IntersectClient.
1 parent 6d0789e commit 441bafb

File tree

4 files changed

+295
-1
lines changed

4 files changed

+295
-1
lines changed

src/intersect_sdk/client.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from __future__ import annotations
1414

1515
import time
16+
from collections import defaultdict
17+
from threading import Event, Thread
1618
from typing import TYPE_CHECKING
1719
from uuid import uuid4
1820

@@ -47,6 +49,7 @@
4749
from .client_callback_definitions import (
4850
INTERSECT_CLIENT_EVENT_CALLBACK_TYPE,
4951
INTERSECT_CLIENT_RESPONSE_CALLBACK_TYPE,
52+
INTERSECT_CLIENT_TIMEOUT_CALLBACK_TYPE,
5053
)
5154
from .shared_callback_definitions import IntersectDirectMessageParams
5255

@@ -71,6 +74,7 @@ def __init__(
7174
config: IntersectClientConfig,
7275
user_callback: INTERSECT_CLIENT_RESPONSE_CALLBACK_TYPE | None = None,
7376
event_callback: INTERSECT_CLIENT_EVENT_CALLBACK_TYPE | None = None,
77+
timeout_callback: INTERSECT_CLIENT_TIMEOUT_CALLBACK_TYPE | None = None,
7478
) -> None:
7579
"""The constructor performs almost all validation checks necessary to function in the INTERSECT ecosystem, with the exception of checking connections/credentials to any backing services.
7680
@@ -79,6 +83,7 @@ def __init__(
7983
user_callback: The callback function you can use to handle response messages from Services.
8084
If this is left empty, you can only send a single message
8185
event_callback: The callback function you can use to handle events from any Service.
86+
timeout_callback: The callback function you can use to handle request timeouts.
8287
"""
8388
# this is called here in case a user created the object using "IntersectClientConfig.model_construct()" to skip validation
8489
config = IntersectClientConfig.model_validate(config)
@@ -87,6 +92,8 @@ def __init__(
8792
die('user_callback function should be a callable function if defined')
8893
if event_callback is not None and not callable(event_callback):
8994
die('event_callback function should be a callable function if defined')
95+
if timeout_callback is not None and not callable(timeout_callback):
96+
die('timeout_callback function should be a callable function if defined')
9097
if not user_callback and not event_callback:
9198
die('must define at least one of user_callback or event_callback')
9299
if not user_callback:
@@ -146,6 +153,10 @@ def __init__(
146153
)
147154
self._user_callback = user_callback
148155
self._event_callback = event_callback
156+
self._timeout_callback = timeout_callback
157+
self._pending_requests: defaultdict[str, list] = defaultdict(list)
158+
self._stop_timeout_thread = Event()
159+
self._timeout_thread = Thread(target=self._check_timeouts, daemon=True)
149160

150161
@final
151162
def startup(self) -> Self:
@@ -172,6 +183,8 @@ def startup(self) -> Self:
172183
# and has nothing to do with the Service at all.
173184
time.sleep(1.0)
174185

186+
self._timeout_thread.start()
187+
175188
if self._resend_initial_messages or not self._sent_initial_messages:
176189
for message in self._initial_messages:
177190
self._send_userspace_message(message)
@@ -200,11 +213,29 @@ def shutdown(self, reason: str | None = None) -> Self:
200213
"""
201214
logger.info(f'Client is shutting down (reason: {reason})')
202215

216+
self._stop_timeout_thread.set()
217+
self._timeout_thread.join()
203218
self._control_plane_manager.disconnect()
204219

205220
logger.info('Client shutdown complete')
206221
return self
207222

223+
def _check_timeouts(self) -> None:
224+
"""Periodically check for timed out requests."""
225+
while not self._stop_timeout_thread.is_set():
226+
now = time.time()
227+
for operation_id, requests in list(self._pending_requests.items()):
228+
for request in requests:
229+
if now > request['timeout']:
230+
try:
231+
request['on_timeout'](operation_id)
232+
except Exception as e:
233+
logger.warning(f'Exception from timeout callback for operation {operation_id}:\n{e}')
234+
requests.remove(request)
235+
if not requests:
236+
del self._pending_requests[operation_id]
237+
time.sleep(0.1) # Sleep for a short duration
238+
208239
@final
209240
def is_connected(self) -> bool:
210241
"""Check if we're currently connected to the INTERSECT brokers.
@@ -258,6 +289,13 @@ def _handle_userspace_message(self, message: UserspaceMessage) -> None:
258289
send_os_signal()
259290
return
260291

292+
# If not in pending requests, it already timed out, so ignore this response
293+
if message['operationId'] in self._pending_requests:
294+
del self._pending_requests[message['operationId']]
295+
else:
296+
logger.debug(f'Received response for operation {message["operationId"]} that already timed out, ignoring')
297+
return
298+
261299
# TWO: GET DATA FROM APPROPRIATE DATA STORE AND DESERIALIZE IT
262300
try:
263301
request_params = GENERIC_MESSAGE_SERIALIZER.validate_json(
@@ -436,3 +474,11 @@ def _send_userspace_message(self, params: IntersectDirectMessageParams) -> None:
436474
# but cannot communicate the response to the Client.
437475
# in experiment controllers or production, you'll want to set persist to True
438476
self._control_plane_manager.publish_message(channel, msg, persist=False)
477+
478+
if params.timeout is not None and params.on_timeout is not None:
479+
self._pending_requests[params.operation].append(
480+
{
481+
'timeout': time.time() + params.timeout,
482+
'on_timeout': params.on_timeout,
483+
}
484+
)

src/intersect_sdk/client_callback_definitions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,16 @@
1212
from .shared_callback_definitions import INTERSECT_JSON_VALUE, IntersectDirectMessageParams
1313

1414

15+
INTERSECT_CLIENT_TIMEOUT_CALLBACK_TYPE = Callable[[str], None]
16+
"""
17+
This is a callable function type which should be defined by the user.
18+
19+
Params
20+
The SDK will send the function one argument:
21+
1) The operation ID of the request that timed out.
22+
"""
23+
24+
1525
@final
1626
class IntersectClientCallback(BaseModel):
1727
"""The value a user should return from ALL client callback functions.

src/intersect_sdk/shared_callback_definitions.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Callback definitions shared between Services, Capabilities, and Clients."""
22

3-
from typing import Any, Dict, List, Union
3+
from typing import Any, Callable, Dict, List, Optional, Union
44

55
from pydantic import BaseModel, ConfigDict, Field
66
from typing_extensions import Annotated, TypeAlias
@@ -65,3 +65,13 @@ class IntersectDirectMessageParams(BaseModel):
6565

6666
# pydantic config
6767
model_config = ConfigDict(revalidate_instances='always')
68+
69+
timeout: Optional[float] = None
70+
"""
71+
The timeout in seconds for the request. If the request is not fulfilled within this time, the on_timeout callback will be called.
72+
"""
73+
74+
on_timeout: Optional[Callable[[], None]] = None
75+
"""
76+
The callback to call if the request times out.
77+
"""

tests/unit/test_client.py

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
from __future__ import annotations
2+
3+
import time
4+
from threading import Event, Thread
5+
from unittest.mock import MagicMock, patch
6+
7+
import pytest
8+
9+
from intersect_sdk.client import IntersectClient
10+
from intersect_sdk.client_callback_definitions import IntersectClientCallback
11+
from intersect_sdk.config.client import IntersectClientConfig
12+
from intersect_sdk.shared_callback_definitions import IntersectDirectMessageParams
13+
14+
15+
def test_timeout_callback_is_called():
16+
"""Tests that the timeout callback is called when a request times out."""
17+
config = IntersectClientConfig(
18+
system='test',
19+
facility='test',
20+
organization='test',
21+
brokers=[{'host': 'localhost', 'port': 1883, 'protocol': 'mqtt3.1.1', 'username': 'test', 'password': 'test'}],
22+
initial_message_event_config=IntersectClientCallback(),
23+
terminate_after_initial_messages=True,
24+
)
25+
26+
timeout_called = []
27+
28+
def on_timeout(operation_id):
29+
timeout_called.append(operation_id)
30+
31+
def user_callback(source, operation_id, has_error, payload):
32+
pass
33+
34+
client = IntersectClient(config, user_callback=user_callback)
35+
36+
# Mock the control plane and data plane managers
37+
client._control_plane_manager = MagicMock()
38+
client._control_plane_manager.is_connected.return_value = True
39+
client._control_plane_manager.considered_unrecoverable.return_value = False
40+
client._data_plane_manager = MagicMock()
41+
client._data_plane_manager.outgoing_message_data_handler.return_value = b'test'
42+
43+
# Manually start the timeout thread (since startup() would try to connect)
44+
client._stop_timeout_thread = Event()
45+
client._timeout_thread = Thread(target=client._check_timeouts, daemon=True)
46+
client._timeout_thread.start()
47+
48+
message = IntersectDirectMessageParams(
49+
destination='test.test.test.test.test',
50+
operation='test_op',
51+
payload='test',
52+
timeout=0.1,
53+
on_timeout=on_timeout,
54+
)
55+
56+
client._send_userspace_message(message)
57+
58+
# Wait for timeout to trigger
59+
time.sleep(0.3)
60+
61+
assert len(timeout_called) == 1
62+
assert timeout_called[0] == 'test_op'
63+
64+
# Clean up
65+
client._stop_timeout_thread.set()
66+
client._timeout_thread.join()
67+
68+
69+
def test_timeout_callback_is_not_called():
70+
"""Tests that the timeout callback is not called when a request is fulfilled."""
71+
config = IntersectClientConfig(
72+
system='test',
73+
facility='test',
74+
organization='test',
75+
brokers=[{'host': 'localhost', 'port': 1883, 'protocol': 'mqtt3.1.1', 'username': 'test', 'password': 'test'}],
76+
initial_message_event_config=IntersectClientCallback(),
77+
terminate_after_initial_messages=True,
78+
)
79+
80+
timeout_called = []
81+
82+
def on_timeout(operation_id):
83+
timeout_called.append(operation_id)
84+
85+
user_callback_called = []
86+
87+
def user_callback(source, operation_id, has_error, payload):
88+
user_callback_called.append(True)
89+
return None
90+
91+
client = IntersectClient(config, user_callback=user_callback)
92+
93+
# Mock the control plane and data plane managers
94+
client._control_plane_manager = MagicMock()
95+
client._control_plane_manager.is_connected.return_value = True
96+
client._control_plane_manager.considered_unrecoverable.return_value = False
97+
client._data_plane_manager = MagicMock()
98+
client._data_plane_manager.outgoing_message_data_handler.return_value = b'test'
99+
client._data_plane_manager.incoming_message_data_handler.return_value = b'"test"'
100+
101+
# Manually start the timeout thread
102+
client._stop_timeout_thread = Event()
103+
client._timeout_thread = Thread(target=client._check_timeouts, daemon=True)
104+
client._timeout_thread.start()
105+
106+
message = IntersectDirectMessageParams(
107+
destination='test.test.test.test.test',
108+
operation='test_op',
109+
payload='test',
110+
timeout=0.5,
111+
on_timeout=on_timeout,
112+
)
113+
114+
client._send_userspace_message(message)
115+
116+
# Simulate receiving a response before the timeout
117+
from datetime import datetime, timezone
118+
from uuid import uuid4
119+
120+
response_message = {
121+
'messageId': str(uuid4()),
122+
'headers': {
123+
'source': 'test.test.test.test.test',
124+
'destination': client._hierarchy.hierarchy_string('.'),
125+
'has_error': False,
126+
'sdk_version': '0.8.0',
127+
'created_at': datetime.now(timezone.utc),
128+
'data_handler': 0, # IntersectDataHandler.MESSAGE
129+
},
130+
'operationId': 'test_op',
131+
'payload': 'test',
132+
'contentType': 'application/json',
133+
}
134+
135+
client._handle_userspace_message(response_message)
136+
137+
# Wait to make sure timeout doesn't fire
138+
time.sleep(0.7)
139+
140+
assert len(timeout_called) == 0
141+
assert len(user_callback_called) == 1
142+
143+
# Clean up
144+
client._stop_timeout_thread.set()
145+
client._timeout_thread.join()
146+
147+
148+
def test_response_after_timeout_is_ignored():
149+
"""Tests that responses arriving after timeout are ignored and user_callback is not called."""
150+
config = IntersectClientConfig(
151+
system='test',
152+
facility='test',
153+
organization='test',
154+
brokers=[{'host': 'localhost', 'port': 1883, 'protocol': 'mqtt3.1.1', 'username': 'test', 'password': 'test'}],
155+
initial_message_event_config=IntersectClientCallback(),
156+
terminate_after_initial_messages=True,
157+
)
158+
159+
timeout_called = []
160+
161+
def on_timeout(operation_id):
162+
timeout_called.append(operation_id)
163+
164+
user_callback_called = []
165+
166+
def user_callback(source, operation_id, has_error, payload):
167+
user_callback_called.append(True)
168+
return None
169+
170+
client = IntersectClient(config, user_callback=user_callback)
171+
172+
# Mock the control plane and data plane managers
173+
client._control_plane_manager = MagicMock()
174+
client._control_plane_manager.is_connected.return_value = True
175+
client._control_plane_manager.considered_unrecoverable.return_value = False
176+
client._data_plane_manager = MagicMock()
177+
client._data_plane_manager.outgoing_message_data_handler.return_value = b'test'
178+
client._data_plane_manager.incoming_message_data_handler.return_value = b'"test"'
179+
180+
# Manually start the timeout thread
181+
client._stop_timeout_thread = Event()
182+
client._timeout_thread = Thread(target=client._check_timeouts, daemon=True)
183+
client._timeout_thread.start()
184+
185+
message = IntersectDirectMessageParams(
186+
destination='test.test.test.test.test',
187+
operation='test_op',
188+
payload='test',
189+
timeout=0.1,
190+
on_timeout=on_timeout,
191+
)
192+
193+
client._send_userspace_message(message)
194+
195+
# Wait for timeout to trigger
196+
time.sleep(0.3)
197+
198+
# Timeout should have been called
199+
assert len(timeout_called) == 1
200+
assert timeout_called[0] == 'test_op'
201+
202+
# Now simulate receiving a late response after timeout
203+
from datetime import datetime, timezone
204+
from uuid import uuid4
205+
206+
response_message = {
207+
'messageId': str(uuid4()),
208+
'headers': {
209+
'source': 'test.test.test.test.test',
210+
'destination': client._hierarchy.hierarchy_string('.'),
211+
'has_error': False,
212+
'sdk_version': '0.8.0',
213+
'created_at': datetime.now(timezone.utc),
214+
'data_handler': 0, # IntersectDataHandler.MESSAGE
215+
},
216+
'operationId': 'test_op',
217+
'payload': 'test',
218+
'contentType': 'application/json',
219+
}
220+
221+
client._handle_userspace_message(response_message)
222+
223+
# User callback should NOT have been called since the request already timed out
224+
assert len(user_callback_called) == 0
225+
226+
# Clean up
227+
client._stop_timeout_thread.set()
228+
client._timeout_thread.join()

0 commit comments

Comments
 (0)