Skip to content

Commit fe2dd8e

Browse files
committed
Use connect function as context manager
Sending a worklist item to the gateway is a brief exchange where we send a json payload and receive an acknowledgement or record a failure. The client connection we make to the gateway needn't be long lived for this simple exchange. This commit removes memoizing the connection on the RelayService class as this had no real effect. The connections would eventually close without a proper ping-pong configuration. Using a contextmanager pattern we ensure the connection is closed when the exchange is complete or an error occurs.
1 parent a2e050f commit fe2dd8e

File tree

2 files changed

+25
-57
lines changed

2 files changed

+25
-57
lines changed

manage_breast_screening/gateway/relay_service.py

Lines changed: 19 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import urllib.parse
99
from datetime import datetime, timezone
1010

11-
from websockets.asyncio.client import ClientConnection, connect
11+
from websockets.asyncio.client import connect
1212

1313
from .models import GatewayAction, GatewayActionStatus, Relay
1414

@@ -80,9 +80,6 @@ def confirmed(self, msg: str):
8080

8181

8282
class RelayService:
83-
def __init__(self):
84-
self._connections: dict[str, object] = {}
85-
8683
def send_action(self, relay: Relay, action: GatewayAction):
8784
"""
8885
Synchronous wrapper around async_send_action.
@@ -97,48 +94,34 @@ async def async_send_action(
9794
result = SendActionResult()
9895

9996
try:
100-
conn = await self._get_connection(relay)
101-
if not conn:
102-
result.failed(f"No connection available for relay {relay.id}")
103-
return result
104-
105-
await conn.send(json.dumps(action.payload))
106-
result.sent(f"Sent action {action.id} to relay {relay.id}")
107-
108-
response = await asyncio.wait_for(
109-
conn.recv(), timeout=RECEIVE_TIMEOUT_SECONDS
110-
)
111-
response_data = json.loads(response)
112-
113-
if response_data.get("status") in ("created", "processed"):
114-
result.confirmed(f"Action {action.id} confirmed by gateway")
115-
else:
116-
result.failed(
117-
f"Unexpected response status from gateway: {response_data}"
97+
relay_uri = RelayURI(relay)
98+
url = relay_uri.connection_url()
99+
async with connect(
100+
url, compression=None, open_timeout=OPEN_CONNECTION_TIMEOUT_SECONDS
101+
) as conn:
102+
await conn.send(json.dumps(action.payload))
103+
result.sent(f"Sent action {action.id} to relay {relay.id}")
104+
105+
response = await asyncio.wait_for(
106+
conn.recv(), timeout=RECEIVE_TIMEOUT_SECONDS
118107
)
108+
response_data = json.loads(response)
109+
110+
if response_data.get("status") in ("created", "processed"):
111+
result.confirmed(f"Action {action.id} confirmed by gateway")
112+
else:
113+
result.failed(
114+
f"Unexpected response status from gateway: {response_data}"
115+
)
119116

120117
except asyncio.TimeoutError:
121118
result.failed(f"Timeout waiting for response from gateway {relay.id}")
122119

123120
except Exception as e:
124121
result.failed(f"Error sending action to gateway {relay.id}: {e}")
125-
self._connections.pop(relay.id, None)
126122

127123
return result
128124

129-
async def _get_connection(self, relay: Relay) -> ClientConnection:
130-
if relay.id not in self._connections:
131-
self._connections[relay.id] = await self._create_connection(RelayURI(relay))
132-
return self._connections[relay.id]
133-
134-
async def _create_connection(self, relay_uri: RelayURI) -> ClientConnection:
135-
url = relay_uri.connection_url()
136-
websocket = await connect(
137-
url, compression=None, open_timeout=OPEN_CONNECTION_TIMEOUT_SECONDS
138-
)
139-
logger.info(f"Created relay connection for relay {relay_uri.relay.id}")
140-
return websocket
141-
142125
def update_gateway_action(self, action: GatewayAction, result: SendActionResult):
143126
"""Update the GatewayAction based on the SendActionResult."""
144127
action.status = result.status

manage_breast_screening/gateway/tests/test_relay_service.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ def test_send_action_success(self, relay, gateway_action):
8181
mock_ws = AsyncMock(spec=ClientConnection)
8282
mock_ws.recv.return_value = json.dumps({"status": "created"})
8383

84-
with patch.object(subject, "_get_connection", return_value=mock_ws):
84+
with patch(f"{RelayService.__module__}.connect") as mock_connect:
85+
mock_connect.return_value.__aenter__.return_value = mock_ws
8586
subject.send_action(relay, gateway_action)
8687

8788
mock_ws.send.assert_called_once_with(
@@ -101,7 +102,8 @@ def test_send_action_timeout(self, relay, gateway_action):
101102
mock_ws = AsyncMock(spec=ClientConnection)
102103
mock_ws.recv.side_effect = asyncio.TimeoutError
103104

104-
with patch.object(subject, "_get_connection", return_value=mock_ws):
105+
with patch(f"{RelayService.__module__}.connect") as mock_connect:
106+
mock_connect.return_value.__aenter__.return_value = mock_ws
105107
subject.send_action(
106108
relay,
107109
gateway_action,
@@ -122,7 +124,8 @@ def test_send_action_bad_response(self, relay, gateway_action):
122124
mock_ws = AsyncMock(spec=ClientConnection)
123125
mock_ws.recv.return_value = json.dumps({"unexpected": "data"})
124126

125-
with patch.object(subject, "_get_connection", return_value=mock_ws):
127+
with patch(f"{RelayService.__module__}.connect") as mock_connect:
128+
mock_connect.return_value.__aenter__.return_value = mock_ws
126129
subject.send_action(
127130
relay,
128131
gateway_action,
@@ -152,21 +155,3 @@ def test_send_action(self, relay, gateway_action):
152155
== "RelayService.async_send_action"
153156
)
154157
coro.close()
155-
156-
@pytest.mark.asyncio
157-
async def test_connection_is_cached(self):
158-
relay = RelayFactory.build()
159-
manager = RelayService()
160-
161-
fake_connection = AsyncMock()
162-
163-
with patch.object(
164-
manager,
165-
"_create_connection",
166-
AsyncMock(return_value=fake_connection),
167-
):
168-
conn1 = await manager._get_connection(relay)
169-
conn2 = await manager._get_connection(relay)
170-
171-
assert conn1 is conn2
172-
assert relay.id in manager._connections

0 commit comments

Comments
 (0)