Skip to content

Commit ab8a27c

Browse files
committed
refactor: use fixtures in test_connection_utils, snapshot mutable list in test_connection_manager
Extract mock_ws_client and coordinator_request fixtures to reduce duplication. Snapshot migrating_from_list per attempt to avoid mutable reference aliasing in assertions.
1 parent efd9b09 commit ab8a27c

File tree

2 files changed

+87
-104
lines changed

2 files changed

+87
-104
lines changed

tests/test_connection_manager.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ async def test_retries_on_sfu_join_error_and_passes_failed_sfus(
5555
async def mock_connect_internal(migrating_from_list=None, **kwargs):
5656
nonlocal call_count
5757
call_count += 1
58-
received_migrating_from_list.append(migrating_from_list)
58+
received_migrating_from_list.append(
59+
list(migrating_from_list) if migrating_from_list else None
60+
)
5961

6062
if call_count <= 2:
6163
mock_join_response = MagicMock()
@@ -76,7 +78,7 @@ async def mock_connect_internal(migrating_from_list=None, **kwargs):
7678

7779
assert call_count == 3
7880
assert received_migrating_from_list[0] is None
79-
assert "sfu-node-1" in received_migrating_from_list[1]
81+
assert received_migrating_from_list[1] == ["sfu-node-1"]
8082
assert received_migrating_from_list[2] == ["sfu-node-1", "sfu-node-2"]
8183

8284
@pytest.mark.asyncio

tests/test_connection_utils.py

Lines changed: 83 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -5,146 +5,127 @@
55
connect_websocket,
66
ConnectionOptions,
77
SfuConnectionError,
8+
SfuJoinError,
89
join_call_coordinator_request,
910
)
1011
from getstream.video.rtc.signaling import SignalingError
1112
from getstream.video.rtc.pb.stream.video.sfu.models import models_pb2
1213

1314

15+
@pytest.fixture
16+
def mock_ws_client():
17+
"""Patch WebSocketClient and yield the mock instance."""
18+
with patch("getstream.video.rtc.connection_utils.WebSocketClient") as mock_ws_cls:
19+
mock_ws = AsyncMock()
20+
mock_ws_cls.return_value = mock_ws
21+
yield mock_ws
22+
23+
24+
@pytest.fixture
25+
def coordinator_request():
26+
"""Set up a mock coordinator client that captures the request body."""
27+
mock_call = AsyncMock()
28+
mock_call.call_type = "default"
29+
mock_call.id = "test_call"
30+
mock_call.client.stream.api_key = "key"
31+
mock_call.client.stream.api_secret = "secret"
32+
mock_call.client.stream.base_url = "https://test.url"
33+
34+
captured_body = {}
35+
36+
with patch("getstream.video.rtc.connection_utils.user_client") as mock_user_client:
37+
mock_client = AsyncMock()
38+
39+
async def capture_post(*args, **kwargs):
40+
captured_body.update(kwargs.get("json", {}))
41+
return AsyncMock()
42+
43+
mock_client.post = capture_post
44+
mock_user_client.return_value = mock_client
45+
yield mock_call, captured_body
46+
47+
1448
class TestConnectWebsocket:
1549
@pytest.mark.asyncio
16-
async def test_raises_sfu_join_error_on_sfu_full(self):
50+
async def test_raises_sfu_join_error_on_sfu_full(self, mock_ws_client):
1751
"""connect_websocket should raise SfuJoinError when SFU is full."""
18-
from getstream.video.rtc.connection_utils import SfuJoinError
19-
20-
# Create a models_pb2.Error with SFU_FULL code
2152
sfu_error = models_pb2.Error(
2253
code=models_pb2.ERROR_CODE_SFU_FULL,
2354
message="server is full",
2455
should_retry=True,
2556
)
26-
signaling_error = SignalingError(
27-
"Connection failed: server is full", error=sfu_error
57+
mock_ws_client.connect = AsyncMock(
58+
side_effect=SignalingError(
59+
"Connection failed: server is full", error=sfu_error
60+
)
2861
)
2962

30-
with patch(
31-
"getstream.video.rtc.connection_utils.WebSocketClient"
32-
) as mock_ws_cls:
33-
mock_ws = AsyncMock()
34-
mock_ws.connect = AsyncMock(side_effect=signaling_error)
35-
mock_ws_cls.return_value = mock_ws
36-
37-
with pytest.raises(SfuJoinError) as exc_info:
38-
await connect_websocket(
39-
token="test_token",
40-
ws_url="wss://test.url",
41-
session_id="test_session",
42-
options=ConnectionOptions(),
43-
)
44-
45-
assert exc_info.value.error_code == models_pb2.ERROR_CODE_SFU_FULL
46-
assert exc_info.value.should_retry is True
47-
# SfuJoinError should be a subclass of SfuConnectionError
48-
assert isinstance(exc_info.value, SfuConnectionError)
63+
with pytest.raises(SfuJoinError) as exc_info:
64+
await connect_websocket(
65+
token="test_token",
66+
ws_url="wss://test.url",
67+
session_id="test_session",
68+
options=ConnectionOptions(),
69+
)
70+
71+
assert exc_info.value.error_code == models_pb2.ERROR_CODE_SFU_FULL
72+
assert exc_info.value.should_retry is True
73+
assert isinstance(exc_info.value, SfuConnectionError)
4974

5075
@pytest.mark.asyncio
51-
async def test_non_retryable_error_propagates_as_signaling_error(self):
76+
async def test_non_retryable_error_propagates_as_signaling_error(
77+
self, mock_ws_client
78+
):
5279
"""Non-retryable SignalingError should not become SfuJoinError."""
53-
from getstream.video.rtc.connection_utils import SfuJoinError
54-
55-
# Error with non-retryable code (e.g. permission denied)
5680
sfu_error = models_pb2.Error(
5781
code=models_pb2.ERROR_CODE_PERMISSION_DENIED,
5882
message="permission denied",
5983
should_retry=False,
6084
)
61-
signaling_error = SignalingError(
62-
"Connection failed: permission denied", error=sfu_error
85+
mock_ws_client.connect = AsyncMock(
86+
side_effect=SignalingError(
87+
"Connection failed: permission denied", error=sfu_error
88+
)
6389
)
6490

65-
with patch(
66-
"getstream.video.rtc.connection_utils.WebSocketClient"
67-
) as mock_ws_cls:
68-
mock_ws = AsyncMock()
69-
mock_ws.connect = AsyncMock(side_effect=signaling_error)
70-
mock_ws_cls.return_value = mock_ws
71-
72-
with pytest.raises(SignalingError) as exc_info:
73-
await connect_websocket(
74-
token="test_token",
75-
ws_url="wss://test.url",
76-
session_id="test_session",
77-
options=ConnectionOptions(),
78-
)
91+
with pytest.raises(SignalingError) as exc_info:
92+
await connect_websocket(
93+
token="test_token",
94+
ws_url="wss://test.url",
95+
session_id="test_session",
96+
options=ConnectionOptions(),
97+
)
7998

80-
assert not isinstance(exc_info.value, SfuJoinError)
99+
assert not isinstance(exc_info.value, SfuJoinError)
81100

82101

83102
class TestJoinCallCoordinatorRequest:
84103
@pytest.mark.asyncio
85-
async def test_includes_migrating_from_in_body(self):
104+
async def test_includes_migrating_from_in_body(self, coordinator_request):
86105
"""migrating_from and migrating_from_list should be included in the request body."""
87-
mock_call = AsyncMock()
88-
mock_call.call_type = "default"
89-
mock_call.id = "test_call"
90-
mock_call.client.stream.api_key = "key"
91-
mock_call.client.stream.api_secret = "secret"
92-
mock_call.client.stream.base_url = "https://test.url"
93-
94-
captured_body = {}
95-
96-
with patch(
97-
"getstream.video.rtc.connection_utils.user_client"
98-
) as mock_user_client:
99-
mock_client = AsyncMock()
100-
101-
async def capture_post(*args, **kwargs):
102-
captured_body.update(kwargs.get("json", {}))
103-
return AsyncMock()
104-
105-
mock_client.post = capture_post
106-
mock_user_client.return_value = mock_client
107-
108-
await join_call_coordinator_request(
109-
call=mock_call,
110-
user_id="user1",
111-
location="auto",
112-
migrating_from="sfu-london-1",
113-
migrating_from_list=["sfu-london-1", "sfu-paris-2"],
114-
)
106+
mock_call, captured_body = coordinator_request
107+
108+
await join_call_coordinator_request(
109+
call=mock_call,
110+
user_id="user1",
111+
location="auto",
112+
migrating_from="sfu-london-1",
113+
migrating_from_list=["sfu-london-1", "sfu-paris-2"],
114+
)
115115

116116
assert captured_body["migrating_from"] == "sfu-london-1"
117117
assert captured_body["migrating_from_list"] == ["sfu-london-1", "sfu-paris-2"]
118118

119119
@pytest.mark.asyncio
120-
async def test_omits_migrating_from_when_not_provided(self):
120+
async def test_omits_migrating_from_when_not_provided(self, coordinator_request):
121121
"""migrating_from should not appear in body when not provided."""
122-
mock_call = AsyncMock()
123-
mock_call.call_type = "default"
124-
mock_call.id = "test_call"
125-
mock_call.client.stream.api_key = "key"
126-
mock_call.client.stream.api_secret = "secret"
127-
mock_call.client.stream.base_url = "https://test.url"
128-
129-
captured_body = {}
130-
131-
with patch(
132-
"getstream.video.rtc.connection_utils.user_client"
133-
) as mock_user_client:
134-
mock_client = AsyncMock()
135-
136-
async def capture_post(*args, **kwargs):
137-
captured_body.update(kwargs.get("json", {}))
138-
return AsyncMock()
139-
140-
mock_client.post = capture_post
141-
mock_user_client.return_value = mock_client
142-
143-
await join_call_coordinator_request(
144-
call=mock_call,
145-
user_id="user1",
146-
location="auto",
147-
)
122+
mock_call, captured_body = coordinator_request
123+
124+
await join_call_coordinator_request(
125+
call=mock_call,
126+
user_id="user1",
127+
location="auto",
128+
)
148129

149130
assert "migrating_from" not in captured_body
150131
assert "migrating_from_list" not in captured_body

0 commit comments

Comments
 (0)