Skip to content

Commit 61da611

Browse files
committed
refactor: extract _handle_join_failure from connect() retry loop
1 parent 311447d commit 61da611

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

getstream/video/rtc/connection_manager.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -468,20 +468,7 @@ def _on_coordinator_task_done(task: asyncio.Task):
468468
return
469469
except SfuJoinError as e:
470470
last_error = e
471-
# Track the failed SFU
472-
if self.join_response and self.join_response.credentials:
473-
edge = self.join_response.credentials.server.edge_name
474-
if edge and edge not in failed_sfus:
475-
failed_sfus.append(edge)
476-
logger.warning(
477-
f"SFU join failed (attempt {attempt + 1}/{1 + self._max_join_retries}, "
478-
f"code={e.error_code}). Failed SFUs: {failed_sfus}"
479-
)
480-
# Clean up partial state before retry
481-
if self._ws_client:
482-
self._ws_client.close()
483-
self._ws_client = None
484-
self.connection_state = ConnectionState.IDLE
471+
self._handle_join_failure(e, attempt, failed_sfus)
485472

486473
if attempt < self._max_join_retries:
487474
delay = 0.5 * (2.0**attempt)
@@ -490,6 +477,23 @@ def _on_coordinator_task_done(task: asyncio.Task):
490477

491478
raise last_error # type: ignore[misc]
492479

480+
def _handle_join_failure(
481+
self, error: SfuJoinError, attempt: int, failed_sfus: list[str]
482+
) -> None:
483+
"""Track a failed SFU and clean up partial connection state."""
484+
if self.join_response and self.join_response.credentials:
485+
edge = self.join_response.credentials.server.edge_name
486+
if edge and edge not in failed_sfus:
487+
failed_sfus.append(edge)
488+
logger.warning(
489+
f"SFU join failed (attempt {attempt + 1}/{1 + self._max_join_retries}, "
490+
f"code={error.error_code}). Failed SFUs: {failed_sfus}"
491+
)
492+
if self._ws_client:
493+
self._ws_client.close()
494+
self._ws_client = None
495+
self.connection_state = ConnectionState.IDLE
496+
493497
async def wait(self):
494498
"""
495499
Wait until the connection is over.

0 commit comments

Comments
 (0)