Skip to content

Commit bc6024c

Browse files
authored
Merge pull request #222 from GetStream/feat/retry-connect-on-sfu-full
Retry connect() on SFU full by requesting a different SFU
2 parents 64754d9 + ab392f2 commit bc6024c

File tree

8 files changed

+542
-46
lines changed

8 files changed

+542
-46
lines changed

getstream/video/rtc/connection_manager.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@
1919
from getstream.video.rtc.connection_utils import (
2020
ConnectionState,
2121
SfuConnectionError,
22+
SfuJoinError,
2223
ConnectionOptions,
2324
connect_websocket,
2425
join_call,
2526
watch_call,
2627
)
28+
from getstream.video.rtc.coordinator.backoff import exp_backoff
2729
from getstream.video.rtc.track_util import (
2830
fix_sdp_msid_semantic,
2931
fix_sdp_rtcp_fb,
@@ -55,6 +57,7 @@ def __init__(
5557
user_id: Optional[str] = None,
5658
create: bool = True,
5759
subscription_config: Optional[SubscriptionConfig] = None,
60+
max_join_retries: int = 3,
5861
**kwargs: Any,
5962
):
6063
super().__init__()
@@ -68,6 +71,9 @@ def __init__(
6871
self.session_id: str = str(uuid.uuid4())
6972
self.join_response: Optional[JoinCallResponse] = None
7073
self.local_sfu: bool = False # Local SFU flag for development
74+
if max_join_retries < 0:
75+
raise ValueError("max_join_retries must be >= 0")
76+
self._max_join_retries: int = max_join_retries
7177

7278
# Private attributes
7379
self._connection_state: ConnectionState = ConnectionState.IDLE
@@ -282,6 +288,7 @@ async def _connect_internal(
282288
ws_url: Optional[str] = None,
283289
token: Optional[str] = None,
284290
session_id: Optional[str] = None,
291+
migrating_from_list: Optional[list] = None,
285292
) -> None:
286293
"""
287294
Internal connection method that handles the core connection logic.
@@ -318,12 +325,15 @@ async def _connect_internal(
318325
if not (ws_url or token):
319326
if self.user_id is None:
320327
raise ValueError("user_id is required for joining a call")
328+
last_failed = migrating_from_list[-1] if migrating_from_list else None
321329
join_response = await join_call(
322330
self.call,
323331
self.user_id,
324332
"auto",
325333
self.create,
326334
self.local_sfu,
335+
migrating_from=last_failed,
336+
migrating_from_list=migrating_from_list,
327337
**self.kwargs,
328338
)
329339
ws_url = join_response.data.credentials.server.ws_endpoint
@@ -395,6 +405,8 @@ async def _connect_internal(
395405
logger.exception(f"No join response from WebSocket: {sfu_event}")
396406

397407
logger.debug(f"WebSocket connected successfully to {ws_url}")
408+
except SfuJoinError:
409+
raise
398410
except Exception as e:
399411
logger.exception(f"Failed to connect WebSocket to {ws_url}: {e}")
400412
raise SfuConnectionError(f"WebSocket connection failed: {e}") from e
@@ -427,7 +439,8 @@ async def connect(self):
427439
Connect to SFU.
428440
429441
This method automatically handles retry logic for transient errors
430-
like "server is full" and network issues.
442+
like "server is full" by requesting a different SFU from the
443+
coordinator.
431444
"""
432445
logger.info("Connecting to SFU")
433446
# Fire-and-forget the coordinator WS connection so we don't block here
@@ -445,7 +458,54 @@ def _on_coordinator_task_done(task: asyncio.Task):
445458
logger.exception("Coordinator WS task failed")
446459

447460
self._coordinator_task.add_done_callback(_on_coordinator_task_done)
448-
await self._connect_internal()
461+
462+
await self._connect_with_sfu_reassignment()
463+
464+
async def _connect_with_sfu_reassignment(self) -> None:
465+
"""Try connecting to SFU, reassigning to a different one on failure."""
466+
failed_sfus: list[str] = []
467+
468+
# First attempt without delay
469+
attempt = 0
470+
try:
471+
await self._connect_internal()
472+
return
473+
except SfuJoinError as e:
474+
self._handle_join_failure(e, attempt, failed_sfus)
475+
if self._max_join_retries == 0:
476+
raise
477+
478+
# Retries with exponential backoff, requesting a different SFU
479+
async for delay in exp_backoff(max_retries=self._max_join_retries, base=0.5):
480+
attempt += 1
481+
logger.info(f"Retrying in {delay}s with different SFU...")
482+
await asyncio.sleep(delay)
483+
try:
484+
await self._connect_internal(
485+
migrating_from_list=failed_sfus if failed_sfus else None,
486+
)
487+
return
488+
except SfuJoinError as e:
489+
self._handle_join_failure(e, attempt, failed_sfus)
490+
if attempt >= self._max_join_retries:
491+
raise
492+
493+
def _handle_join_failure(
494+
self, error: SfuJoinError, attempt: int, failed_sfus: list[str]
495+
) -> None:
496+
"""Track a failed SFU and clean up partial connection state."""
497+
if self.join_response and self.join_response.credentials:
498+
edge = self.join_response.credentials.server.edge_name
499+
if edge and edge not in failed_sfus:
500+
failed_sfus.append(edge)
501+
logger.warning(
502+
f"SFU join failed (attempt {attempt + 1}/{1 + self._max_join_retries}, "
503+
f"code={error.error_code}). Failed SFUs: {failed_sfus}"
504+
)
505+
if self._ws_client:
506+
self._ws_client.close()
507+
self._ws_client = None
508+
self.connection_state = ConnectionState.IDLE
449509

450510
async def wait(self):
451511
"""

getstream/video/rtc/connection_utils.py

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -58,20 +58,6 @@
5858
"connect_websocket",
5959
]
6060

61-
# Private constants - internal use only
62-
_RETRYABLE_ERROR_PATTERNS = [
63-
"server is full",
64-
"server overloaded",
65-
"capacity exceeded",
66-
"try again later",
67-
"service unavailable",
68-
"connection timeout",
69-
"network error",
70-
"temporary failure",
71-
"connection refused",
72-
"connection reset",
73-
]
74-
7561

7662
# Public classes and exceptions
7763
class ConnectionState(Enum):
@@ -94,6 +80,22 @@ class SfuConnectionError(Exception):
9480
pass
9581

9682

83+
class SfuJoinError(SfuConnectionError):
84+
"""Raised when SFU join fails with a retryable error code."""
85+
86+
def __init__(self, message: str, error_code: int = 0, should_retry: bool = False):
87+
super().__init__(message)
88+
self.error_code = error_code
89+
self.should_retry = should_retry
90+
91+
92+
_RETRYABLE_SFU_ERROR_CODES = {
93+
700, # ERROR_CODE_SFU_FULL
94+
600, # ERROR_CODE_SFU_SHUTTING_DOWN
95+
301, # ERROR_CODE_CALL_PARTICIPANT_LIMIT_REACHED
96+
}
97+
98+
9799
@dataclass
98100
class ConnectionOptions:
99101
"""Options for the connection process."""
@@ -175,6 +177,8 @@ async def join_call_coordinator_request(
175177
notify: Optional[bool] = None,
176178
video: Optional[bool] = None,
177179
location: Optional[str] = None,
180+
migrating_from: Optional[str] = None,
181+
migrating_from_list: Optional[list] = None,
178182
) -> StreamResponse[JoinCallResponse]:
179183
"""Make a request to join a call via the coordinator.
180184
@@ -208,6 +212,10 @@ async def join_call_coordinator_request(
208212
video=video,
209213
data=data,
210214
)
215+
if migrating_from:
216+
json_body["migrating_from"] = migrating_from
217+
if migrating_from_list:
218+
json_body["migrating_from_list"] = migrating_from_list
211219

212220
# Make the POST request to join the call
213221
return await client.post(
@@ -423,6 +431,8 @@ async def connect_websocket(
423431
"""
424432
logger.info(f"Connecting to WebSocket at {ws_url}")
425433

434+
ws_client = None
435+
success = False
426436
try:
427437
# Create JoinRequest for WebSocket connection
428438
join_request = await create_join_request(token, session_id)
@@ -448,34 +458,24 @@ async def connect_websocket(
448458
sfu_event = await ws_client.connect()
449459

450460
logger.debug("WebSocket connection established")
461+
success = True
451462
return ws_client, sfu_event
452463

464+
except SignalingError as e:
465+
if (
466+
e.error
467+
and hasattr(e.error, "code")
468+
and e.error.code in _RETRYABLE_SFU_ERROR_CODES
469+
):
470+
raise SfuJoinError(
471+
str(e),
472+
error_code=e.error.code,
473+
should_retry=True,
474+
) from e
475+
raise
453476
except Exception as e:
454477
logger.error(f"Failed to connect WebSocket to {ws_url}: {e}")
455478
raise SignalingError(f"WebSocket connection failed: {e}")
456-
457-
458-
# Private functions
459-
def _is_retryable(retry_state: Any) -> bool:
460-
"""Check if an error should be retried.
461-
462-
Args:
463-
retry_state: The retry state object from tenacity
464-
465-
Returns:
466-
True if the error should be retried, False otherwise
467-
"""
468-
# Extract the actual exception from the retry state
469-
if hasattr(retry_state, "outcome") and retry_state.outcome.failed:
470-
error = retry_state.outcome.exception()
471-
else:
472-
return False
473-
474-
# Import here to avoid circular imports
475-
from getstream.video.rtc.signaling import SignalingError
476-
477-
if not isinstance(error, (SignalingError, SfuConnectionError)):
478-
return False
479-
480-
error_message = str(error).lower()
481-
return any(pattern in error_message for pattern in _RETRYABLE_ERROR_PATTERNS)
479+
finally:
480+
if ws_client and not success:
481+
ws_client.close()

getstream/video/rtc/signaling.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
class SignalingError(Exception):
2727
"""Exception raised for errors in the signaling process."""
2828

29-
pass
29+
def __init__(self, message: str, error=None):
30+
super().__init__(message)
31+
self.error = error
3032

3133

3234
class WebSocketClient(StreamAsyncIOEventEmitter):
@@ -111,8 +113,10 @@ async def connect(self):
111113

112114
# Check if the first message is an error
113115
if self.first_message and self.first_message.HasField("error"):
114-
error_msg = self.first_message.error.error.message
115-
raise SignalingError(f"Connection failed: {error_msg}")
116+
sfu_error = self.first_message.error.error
117+
raise SignalingError(
118+
f"Connection failed: {sfu_error.message}", error=sfu_error
119+
)
116120

117121
# Check if we got join_response
118122
if self.first_message and self.first_message.HasField("join_response"):

scripts/test_sfu_connect.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Utility script for testing SFU connection and retry behavior.
4+
5+
Connects to a call as a given user and logs each step of the connection
6+
process — useful for verifying SFU assignment, retry on transient errors
7+
(e.g. SFU_FULL), and reassignment via the coordinator.
8+
9+
Environment variables
10+
---------------------
11+
STREAM_API_KEY — Stream API key (required)
12+
STREAM_API_SECRET — Stream API secret (required)
13+
STREAM_BASE_URL — Coordinator URL (default: Stream cloud).
14+
Set to http://127.0.0.1:3030 for a local coordinator.
15+
USER_ID — User ID to join as (default: "test-user").
16+
CALL_TYPE — Call type (default: "default").
17+
CALL_ID — Call ID. If not set, a random UUID is generated.
18+
19+
Usage
20+
-----
21+
# Connect via cloud coordinator
22+
STREAM_API_KEY=... STREAM_API_SECRET=... \\
23+
uv run --extra webrtc python scripts/test_sfu_connect.py
24+
25+
# Connect via local coordinator
26+
STREAM_BASE_URL=http://127.0.0.1:3030 \\
27+
uv run --extra webrtc python scripts/test_sfu_connect.py
28+
"""
29+
30+
import asyncio
31+
import logging
32+
import os
33+
import uuid
34+
35+
from dotenv import load_dotenv
36+
37+
from getstream import AsyncStream
38+
from getstream.models import CallRequest
39+
from getstream.video.rtc import ConnectionManager
40+
41+
load_dotenv()
42+
43+
logging.basicConfig(
44+
level=logging.INFO,
45+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
46+
)
47+
logger = logging.getLogger(__name__)
48+
49+
50+
async def run():
51+
base_url = os.getenv("STREAM_BASE_URL")
52+
user_id = os.getenv("USER_ID", "test-user")
53+
call_type = os.getenv("CALL_TYPE", "default")
54+
call_id = os.getenv("CALL_ID", str(uuid.uuid4()))
55+
56+
logger.info("Configuration:")
57+
logger.info(f" Coordinator: {base_url or 'cloud (default)'}")
58+
logger.info(f" User: {user_id}")
59+
logger.info(f" Call: {call_type}:{call_id}")
60+
61+
client_kwargs = {}
62+
if base_url:
63+
client_kwargs["base_url"] = base_url
64+
65+
client = AsyncStream(timeout=10.0, **client_kwargs)
66+
67+
call = client.video.call(call_type, call_id)
68+
logger.info("Creating call...")
69+
await call.get_or_create(data=CallRequest(created_by_id=user_id))
70+
logger.info("Call created")
71+
72+
cm = ConnectionManager(
73+
call=call,
74+
user_id=user_id,
75+
create=False,
76+
)
77+
78+
logger.info("Connecting to SFU...")
79+
80+
async with cm:
81+
join = cm.join_response
82+
if join and join.credentials:
83+
logger.info(f"Connected to SFU: {join.credentials.server.edge_name}")
84+
logger.info(f" WS endpoint: {join.credentials.server.ws_endpoint}")
85+
logger.info(f" Session ID: {cm.session_id}")
86+
87+
logger.info("Holding connection for 3s...")
88+
await asyncio.sleep(3)
89+
90+
logger.info("Leaving call")
91+
92+
logger.info("Done")
93+
94+
95+
if __name__ == "__main__":
96+
asyncio.run(run())

0 commit comments

Comments
 (0)