|
5 | 5 | connect_websocket, |
6 | 6 | ConnectionOptions, |
7 | 7 | SfuConnectionError, |
| 8 | + SfuJoinError, |
8 | 9 | join_call_coordinator_request, |
9 | 10 | ) |
10 | 11 | from getstream.video.rtc.signaling import SignalingError |
11 | 12 | from getstream.video.rtc.pb.stream.video.sfu.models import models_pb2 |
12 | 13 |
|
13 | 14 |
|
| 15 | +@pytest.fixture |
| 16 | +def mock_ws_client(): |
| 17 | + """Patch WebSocketClient and yield the mock instance.""" |
| 18 | + with patch("getstream.video.rtc.connection_utils.WebSocketClient") as mock_ws_cls: |
| 19 | + mock_ws = AsyncMock() |
| 20 | + mock_ws_cls.return_value = mock_ws |
| 21 | + yield mock_ws |
| 22 | + |
| 23 | + |
| 24 | +@pytest.fixture |
| 25 | +def coordinator_request(): |
| 26 | + """Set up a mock coordinator client that captures the request body.""" |
| 27 | + mock_call = AsyncMock() |
| 28 | + mock_call.call_type = "default" |
| 29 | + mock_call.id = "test_call" |
| 30 | + mock_call.client.stream.api_key = "key" |
| 31 | + mock_call.client.stream.api_secret = "secret" |
| 32 | + mock_call.client.stream.base_url = "https://test.url" |
| 33 | + |
| 34 | + captured_body = {} |
| 35 | + |
| 36 | + with patch("getstream.video.rtc.connection_utils.user_client") as mock_user_client: |
| 37 | + mock_client = AsyncMock() |
| 38 | + |
| 39 | + async def capture_post(*args, **kwargs): |
| 40 | + captured_body.update(kwargs.get("json", {})) |
| 41 | + return AsyncMock() |
| 42 | + |
| 43 | + mock_client.post = capture_post |
| 44 | + mock_user_client.return_value = mock_client |
| 45 | + yield mock_call, captured_body |
| 46 | + |
| 47 | + |
14 | 48 | class TestConnectWebsocket: |
15 | 49 | @pytest.mark.asyncio |
16 | | - async def test_raises_sfu_join_error_on_sfu_full(self): |
| 50 | + async def test_raises_sfu_join_error_on_sfu_full(self, mock_ws_client): |
17 | 51 | """connect_websocket should raise SfuJoinError when SFU is full.""" |
18 | | - from getstream.video.rtc.connection_utils import SfuJoinError |
19 | | - |
20 | | - # Create a models_pb2.Error with SFU_FULL code |
21 | 52 | sfu_error = models_pb2.Error( |
22 | 53 | code=models_pb2.ERROR_CODE_SFU_FULL, |
23 | 54 | message="server is full", |
24 | 55 | should_retry=True, |
25 | 56 | ) |
26 | | - signaling_error = SignalingError( |
27 | | - "Connection failed: server is full", error=sfu_error |
| 57 | + mock_ws_client.connect = AsyncMock( |
| 58 | + side_effect=SignalingError( |
| 59 | + "Connection failed: server is full", error=sfu_error |
| 60 | + ) |
28 | 61 | ) |
29 | 62 |
|
30 | | - with patch( |
31 | | - "getstream.video.rtc.connection_utils.WebSocketClient" |
32 | | - ) as mock_ws_cls: |
33 | | - mock_ws = AsyncMock() |
34 | | - mock_ws.connect = AsyncMock(side_effect=signaling_error) |
35 | | - mock_ws_cls.return_value = mock_ws |
36 | | - |
37 | | - with pytest.raises(SfuJoinError) as exc_info: |
38 | | - await connect_websocket( |
39 | | - token="test_token", |
40 | | - ws_url="wss://test.url", |
41 | | - session_id="test_session", |
42 | | - options=ConnectionOptions(), |
43 | | - ) |
44 | | - |
45 | | - assert exc_info.value.error_code == models_pb2.ERROR_CODE_SFU_FULL |
46 | | - assert exc_info.value.should_retry is True |
47 | | - # SfuJoinError should be a subclass of SfuConnectionError |
48 | | - assert isinstance(exc_info.value, SfuConnectionError) |
| 63 | + with pytest.raises(SfuJoinError) as exc_info: |
| 64 | + await connect_websocket( |
| 65 | + token="test_token", |
| 66 | + ws_url="wss://test.url", |
| 67 | + session_id="test_session", |
| 68 | + options=ConnectionOptions(), |
| 69 | + ) |
| 70 | + |
| 71 | + assert exc_info.value.error_code == models_pb2.ERROR_CODE_SFU_FULL |
| 72 | + assert exc_info.value.should_retry is True |
| 73 | + assert isinstance(exc_info.value, SfuConnectionError) |
49 | 74 |
|
50 | 75 | @pytest.mark.asyncio |
51 | | - async def test_non_retryable_error_propagates_as_signaling_error(self): |
| 76 | + async def test_non_retryable_error_propagates_as_signaling_error( |
| 77 | + self, mock_ws_client |
| 78 | + ): |
52 | 79 | """Non-retryable SignalingError should not become SfuJoinError.""" |
53 | | - from getstream.video.rtc.connection_utils import SfuJoinError |
54 | | - |
55 | | - # Error with non-retryable code (e.g. permission denied) |
56 | 80 | sfu_error = models_pb2.Error( |
57 | 81 | code=models_pb2.ERROR_CODE_PERMISSION_DENIED, |
58 | 82 | message="permission denied", |
59 | 83 | should_retry=False, |
60 | 84 | ) |
61 | | - signaling_error = SignalingError( |
62 | | - "Connection failed: permission denied", error=sfu_error |
| 85 | + mock_ws_client.connect = AsyncMock( |
| 86 | + side_effect=SignalingError( |
| 87 | + "Connection failed: permission denied", error=sfu_error |
| 88 | + ) |
63 | 89 | ) |
64 | 90 |
|
65 | | - with patch( |
66 | | - "getstream.video.rtc.connection_utils.WebSocketClient" |
67 | | - ) as mock_ws_cls: |
68 | | - mock_ws = AsyncMock() |
69 | | - mock_ws.connect = AsyncMock(side_effect=signaling_error) |
70 | | - mock_ws_cls.return_value = mock_ws |
71 | | - |
72 | | - with pytest.raises(SignalingError) as exc_info: |
73 | | - await connect_websocket( |
74 | | - token="test_token", |
75 | | - ws_url="wss://test.url", |
76 | | - session_id="test_session", |
77 | | - options=ConnectionOptions(), |
78 | | - ) |
| 91 | + with pytest.raises(SignalingError) as exc_info: |
| 92 | + await connect_websocket( |
| 93 | + token="test_token", |
| 94 | + ws_url="wss://test.url", |
| 95 | + session_id="test_session", |
| 96 | + options=ConnectionOptions(), |
| 97 | + ) |
79 | 98 |
|
80 | | - assert not isinstance(exc_info.value, SfuJoinError) |
| 99 | + assert not isinstance(exc_info.value, SfuJoinError) |
81 | 100 |
|
82 | 101 |
|
83 | 102 | class TestJoinCallCoordinatorRequest: |
84 | 103 | @pytest.mark.asyncio |
85 | | - async def test_includes_migrating_from_in_body(self): |
| 104 | + async def test_includes_migrating_from_in_body(self, coordinator_request): |
86 | 105 | """migrating_from and migrating_from_list should be included in the request body.""" |
87 | | - mock_call = AsyncMock() |
88 | | - mock_call.call_type = "default" |
89 | | - mock_call.id = "test_call" |
90 | | - mock_call.client.stream.api_key = "key" |
91 | | - mock_call.client.stream.api_secret = "secret" |
92 | | - mock_call.client.stream.base_url = "https://test.url" |
93 | | - |
94 | | - captured_body = {} |
95 | | - |
96 | | - with patch( |
97 | | - "getstream.video.rtc.connection_utils.user_client" |
98 | | - ) as mock_user_client: |
99 | | - mock_client = AsyncMock() |
100 | | - |
101 | | - async def capture_post(*args, **kwargs): |
102 | | - captured_body.update(kwargs.get("json", {})) |
103 | | - return AsyncMock() |
104 | | - |
105 | | - mock_client.post = capture_post |
106 | | - mock_user_client.return_value = mock_client |
107 | | - |
108 | | - await join_call_coordinator_request( |
109 | | - call=mock_call, |
110 | | - user_id="user1", |
111 | | - location="auto", |
112 | | - migrating_from="sfu-london-1", |
113 | | - migrating_from_list=["sfu-london-1", "sfu-paris-2"], |
114 | | - ) |
| 106 | + mock_call, captured_body = coordinator_request |
| 107 | + |
| 108 | + await join_call_coordinator_request( |
| 109 | + call=mock_call, |
| 110 | + user_id="user1", |
| 111 | + location="auto", |
| 112 | + migrating_from="sfu-london-1", |
| 113 | + migrating_from_list=["sfu-london-1", "sfu-paris-2"], |
| 114 | + ) |
115 | 115 |
|
116 | 116 | assert captured_body["migrating_from"] == "sfu-london-1" |
117 | 117 | assert captured_body["migrating_from_list"] == ["sfu-london-1", "sfu-paris-2"] |
118 | 118 |
|
119 | 119 | @pytest.mark.asyncio |
120 | | - async def test_omits_migrating_from_when_not_provided(self): |
| 120 | + async def test_omits_migrating_from_when_not_provided(self, coordinator_request): |
121 | 121 | """migrating_from should not appear in body when not provided.""" |
122 | | - mock_call = AsyncMock() |
123 | | - mock_call.call_type = "default" |
124 | | - mock_call.id = "test_call" |
125 | | - mock_call.client.stream.api_key = "key" |
126 | | - mock_call.client.stream.api_secret = "secret" |
127 | | - mock_call.client.stream.base_url = "https://test.url" |
128 | | - |
129 | | - captured_body = {} |
130 | | - |
131 | | - with patch( |
132 | | - "getstream.video.rtc.connection_utils.user_client" |
133 | | - ) as mock_user_client: |
134 | | - mock_client = AsyncMock() |
135 | | - |
136 | | - async def capture_post(*args, **kwargs): |
137 | | - captured_body.update(kwargs.get("json", {})) |
138 | | - return AsyncMock() |
139 | | - |
140 | | - mock_client.post = capture_post |
141 | | - mock_user_client.return_value = mock_client |
142 | | - |
143 | | - await join_call_coordinator_request( |
144 | | - call=mock_call, |
145 | | - user_id="user1", |
146 | | - location="auto", |
147 | | - ) |
| 122 | + mock_call, captured_body = coordinator_request |
| 123 | + |
| 124 | + await join_call_coordinator_request( |
| 125 | + call=mock_call, |
| 126 | + user_id="user1", |
| 127 | + location="auto", |
| 128 | + ) |
148 | 129 |
|
149 | 130 | assert "migrating_from" not in captured_body |
150 | 131 | assert "migrating_from_list" not in captured_body |
0 commit comments