Skip to content

Commit bca75d7

Browse files
authored
Merge pull request #1135 from NHSDigital/fix/make-relay-connections-shortlived
Use connect function as context manager
2 parents 3f1b6f4 + fe2dd8e commit bca75d7

File tree

2 files changed

+25
-57
lines changed

2 files changed

+25
-57
lines changed

manage_breast_screening/gateway/relay_service.py

Lines changed: 19 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import urllib.parse
99
from datetime import datetime, timezone
1010

11-
from websockets.asyncio.client import ClientConnection, connect
11+
from websockets.asyncio.client import connect
1212

1313
from .models import GatewayAction, GatewayActionStatus, Relay
1414

@@ -80,9 +80,6 @@ def confirmed(self, msg: str):
8080

8181

8282
class RelayService:
83-
def __init__(self):
84-
self._connections: dict[str, object] = {}
85-
8683
def send_action(self, relay: Relay, action: GatewayAction):
8784
"""
8885
Synchronous wrapper around async_send_action.
@@ -97,48 +94,34 @@ async def async_send_action(
9794
result = SendActionResult()
9895

9996
try:
100-
conn = await self._get_connection(relay)
101-
if not conn:
102-
result.failed(f"No connection available for relay {relay.id}")
103-
return result
104-
105-
await conn.send(json.dumps(action.payload))
106-
result.sent(f"Sent action {action.id} to relay {relay.id}")
107-
108-
response = await asyncio.wait_for(
109-
conn.recv(), timeout=RECEIVE_TIMEOUT_SECONDS
110-
)
111-
response_data = json.loads(response)
112-
113-
if response_data.get("status") in ("created", "processed"):
114-
result.confirmed(f"Action {action.id} confirmed by gateway")
115-
else:
116-
result.failed(
117-
f"Unexpected response status from gateway: {response_data}"
97+
relay_uri = RelayURI(relay)
98+
url = relay_uri.connection_url()
99+
async with connect(
100+
url, compression=None, open_timeout=OPEN_CONNECTION_TIMEOUT_SECONDS
101+
) as conn:
102+
await conn.send(json.dumps(action.payload))
103+
result.sent(f"Sent action {action.id} to relay {relay.id}")
104+
105+
response = await asyncio.wait_for(
106+
conn.recv(), timeout=RECEIVE_TIMEOUT_SECONDS
118107
)
108+
response_data = json.loads(response)
109+
110+
if response_data.get("status") in ("created", "processed"):
111+
result.confirmed(f"Action {action.id} confirmed by gateway")
112+
else:
113+
result.failed(
114+
f"Unexpected response status from gateway: {response_data}"
115+
)
119116

120117
except asyncio.TimeoutError:
121118
result.failed(f"Timeout waiting for response from gateway {relay.id}")
122119

123120
except Exception as e:
124121
result.failed(f"Error sending action to gateway {relay.id}: {e}")
125-
self._connections.pop(relay.id, None)
126122

127123
return result
128124

129-
async def _get_connection(self, relay: Relay) -> ClientConnection:
130-
if relay.id not in self._connections:
131-
self._connections[relay.id] = await self._create_connection(RelayURI(relay))
132-
return self._connections[relay.id]
133-
134-
async def _create_connection(self, relay_uri: RelayURI) -> ClientConnection:
135-
url = relay_uri.connection_url()
136-
websocket = await connect(
137-
url, compression=None, open_timeout=OPEN_CONNECTION_TIMEOUT_SECONDS
138-
)
139-
logger.info(f"Created relay connection for relay {relay_uri.relay.id}")
140-
return websocket
141-
142125
def update_gateway_action(self, action: GatewayAction, result: SendActionResult):
143126
"""Update the GatewayAction based on the SendActionResult."""
144127
action.status = result.status

manage_breast_screening/gateway/tests/test_relay_service.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ def test_send_action_success(self, relay, gateway_action):
8181
mock_ws = AsyncMock(spec=ClientConnection)
8282
mock_ws.recv.return_value = json.dumps({"status": "created"})
8383

84-
with patch.object(subject, "_get_connection", return_value=mock_ws):
84+
with patch(f"{RelayService.__module__}.connect") as mock_connect:
85+
mock_connect.return_value.__aenter__.return_value = mock_ws
8586
subject.send_action(relay, gateway_action)
8687

8788
mock_ws.send.assert_called_once_with(
@@ -101,7 +102,8 @@ def test_send_action_timeout(self, relay, gateway_action):
101102
mock_ws = AsyncMock(spec=ClientConnection)
102103
mock_ws.recv.side_effect = asyncio.TimeoutError
103104

104-
with patch.object(subject, "_get_connection", return_value=mock_ws):
105+
with patch(f"{RelayService.__module__}.connect") as mock_connect:
106+
mock_connect.return_value.__aenter__.return_value = mock_ws
105107
subject.send_action(
106108
relay,
107109
gateway_action,
@@ -122,7 +124,8 @@ def test_send_action_bad_response(self, relay, gateway_action):
122124
mock_ws = AsyncMock(spec=ClientConnection)
123125
mock_ws.recv.return_value = json.dumps({"unexpected": "data"})
124126

125-
with patch.object(subject, "_get_connection", return_value=mock_ws):
127+
with patch(f"{RelayService.__module__}.connect") as mock_connect:
128+
mock_connect.return_value.__aenter__.return_value = mock_ws
126129
subject.send_action(
127130
relay,
128131
gateway_action,
@@ -152,21 +155,3 @@ def test_send_action(self, relay, gateway_action):
152155
== "RelayService.async_send_action"
153156
)
154157
coro.close()
155-
156-
@pytest.mark.asyncio
157-
async def test_connection_is_cached(self):
158-
relay = RelayFactory.build()
159-
manager = RelayService()
160-
161-
fake_connection = AsyncMock()
162-
163-
with patch.object(
164-
manager,
165-
"_create_connection",
166-
AsyncMock(return_value=fake_connection),
167-
):
168-
conn1 = await manager._get_connection(relay)
169-
conn2 = await manager._get_connection(relay)
170-
171-
assert conn1 is conn2
172-
assert relay.id in manager._connections

0 commit comments

Comments
 (0)