|
12 | 12 | from openhands.sdk.agent.acp_agent import ( |
13 | 13 | ACPAgent, |
14 | 14 | _OpenHandsACPBridge, |
| 15 | + _is_retriable_connection_error, |
15 | 16 | _resolve_bypass_mode, |
16 | 17 | _select_auth_method, |
17 | 18 | ) |
@@ -1387,3 +1388,194 @@ def test_serialization_roundtrip(self): |
1387 | 1388 | restored = AgentBase.model_validate_json(dumped) |
1388 | 1389 | assert isinstance(restored, ACPAgent) |
1389 | 1390 | assert restored.acp_session_mode == "full-access" |
| 1391 | + |
| 1392 | + |
| 1393 | +# --------------------------------------------------------------------------- |
| 1394 | +# Connection retry logic |
| 1395 | +# --------------------------------------------------------------------------- |
| 1396 | + |
| 1397 | + |
| 1398 | +class TestIsRetriableConnectionError: |
| 1399 | + """Test _is_retriable_connection_error classification.""" |
| 1400 | + |
| 1401 | + def test_oserror_is_retriable(self): |
| 1402 | + assert _is_retriable_connection_error(OSError("Connection reset")) |
| 1403 | + |
| 1404 | + def test_connection_error_is_retriable(self): |
| 1405 | + assert _is_retriable_connection_error(ConnectionError("Connection refused")) |
| 1406 | + |
| 1407 | + def test_broken_pipe_is_retriable(self): |
| 1408 | + assert _is_retriable_connection_error(BrokenPipeError("Broken pipe")) |
| 1409 | + |
| 1410 | + def test_eof_error_is_retriable(self): |
| 1411 | + assert _is_retriable_connection_error(EOFError("Unexpected EOF")) |
| 1412 | + |
| 1413 | + def test_connection_closed_message_is_retriable(self): |
| 1414 | + assert _is_retriable_connection_error(RuntimeError("connection closed by peer")) |
| 1415 | + |
| 1416 | + def test_server_disconnected_is_retriable(self): |
| 1417 | + assert _is_retriable_connection_error(Exception("server disconnected")) |
| 1418 | + |
| 1419 | + def test_usage_policy_not_retriable(self): |
| 1420 | + assert not _is_retriable_connection_error( |
| 1421 | + RuntimeError("Usage policy violation") |
| 1422 | + ) |
| 1423 | + |
| 1424 | + def test_content_policy_not_retriable(self): |
| 1425 | + assert not _is_retriable_connection_error( |
| 1426 | + RuntimeError("Content policy blocked") |
| 1427 | + ) |
| 1428 | + |
| 1429 | + def test_permission_denied_not_retriable(self): |
| 1430 | + assert not _is_retriable_connection_error( |
| 1431 | + RuntimeError("Permission denied for operation") |
| 1432 | + ) |
| 1433 | + |
| 1434 | + def test_generic_error_not_retriable(self): |
| 1435 | + # Generic errors without connection patterns should not be retried |
| 1436 | + assert not _is_retriable_connection_error(RuntimeError("Something went wrong")) |
| 1437 | + |
| 1438 | + |
| 1439 | +class TestACPPromptRetry: |
| 1440 | + """Test retry logic for ACP prompt failures.""" |
| 1441 | + |
| 1442 | + def _make_conversation_with_message(self, tmp_path, text="Hello"): |
| 1443 | + """Create a mock conversation with a user message.""" |
| 1444 | + state = _make_state(tmp_path) |
| 1445 | + state.events.append( |
| 1446 | + SystemPromptEvent( |
| 1447 | + source="agent", |
| 1448 | + system_prompt=TextContent(text="ACP-managed agent"), |
| 1449 | + tools=[], |
| 1450 | + ) |
| 1451 | + ) |
| 1452 | + state.events.append( |
| 1453 | + MessageEvent( |
| 1454 | + source="user", |
| 1455 | + llm_message=Message(role="user", content=[TextContent(text=text)]), |
| 1456 | + ) |
| 1457 | + ) |
| 1458 | + |
| 1459 | + conversation = MagicMock() |
| 1460 | + conversation.state = state |
| 1461 | + return conversation |
| 1462 | + |
| 1463 | + def test_retry_on_connection_error_then_success(self, tmp_path): |
| 1464 | + """Retry succeeds after transient connection error.""" |
| 1465 | + agent = _make_agent() |
| 1466 | + conversation = self._make_conversation_with_message(tmp_path) |
| 1467 | + events: list = [] |
| 1468 | + |
| 1469 | + mock_client = _OpenHandsACPBridge() |
| 1470 | + agent._client = mock_client |
| 1471 | + agent._conn = MagicMock() |
| 1472 | + agent._session_id = "test-session" |
| 1473 | + |
| 1474 | + call_count = 0 |
| 1475 | + |
| 1476 | + def _fake_run_async(_coro, **_kwargs): |
| 1477 | + nonlocal call_count |
| 1478 | + call_count += 1 |
| 1479 | + if call_count == 1: |
| 1480 | + raise ConnectionError("Connection reset by peer") |
| 1481 | + # Second call succeeds - must populate text and return a response |
| 1482 | + mock_client.accumulated_text.append("Success after retry") |
| 1483 | + # Return a mock PromptResponse (can be MagicMock since we only check usage) |
| 1484 | + return MagicMock(usage=None) |
| 1485 | + |
| 1486 | + mock_executor = MagicMock() |
| 1487 | + mock_executor.run_async = _fake_run_async |
| 1488 | + agent._executor = mock_executor |
| 1489 | + |
| 1490 | + # Patch sleep to avoid actual delays in tests |
| 1491 | + with patch("openhands.sdk.agent.acp_agent.time.sleep"): |
| 1492 | + agent.step(conversation, on_event=events.append) |
| 1493 | + |
| 1494 | + assert call_count == 2 # First failed, second succeeded |
| 1495 | + assert conversation.state.execution_status == ConversationExecutionStatus.FINISHED |
| 1496 | + assert len(events) == 3 # MessageEvent, ActionEvent, ObservationEvent |
| 1497 | + assert "Success after retry" in events[0].llm_message.content[0].text |
| 1498 | + |
| 1499 | + def test_no_retry_on_non_retriable_error(self, tmp_path): |
| 1500 | + """Non-retriable errors fail immediately without retry.""" |
| 1501 | + agent = _make_agent() |
| 1502 | + conversation = self._make_conversation_with_message(tmp_path) |
| 1503 | + events: list = [] |
| 1504 | + |
| 1505 | + mock_client = _OpenHandsACPBridge() |
| 1506 | + agent._client = mock_client |
| 1507 | + agent._conn = MagicMock() |
| 1508 | + agent._session_id = "test-session" |
| 1509 | + |
| 1510 | + call_count = 0 |
| 1511 | + |
| 1512 | + def _fake_run_async(_coro, **_kwargs): |
| 1513 | + nonlocal call_count |
| 1514 | + call_count += 1 |
| 1515 | + raise RuntimeError("Usage policy violation") |
| 1516 | + |
| 1517 | + mock_executor = MagicMock() |
| 1518 | + mock_executor.run_async = _fake_run_async |
| 1519 | + agent._executor = mock_executor |
| 1520 | + |
| 1521 | + with pytest.raises(RuntimeError, match="Usage policy violation"): |
| 1522 | + agent.step(conversation, on_event=events.append) |
| 1523 | + |
| 1524 | + assert call_count == 1 # No retry attempted |
| 1525 | + assert conversation.state.execution_status == ConversationExecutionStatus.ERROR |
| 1526 | + |
| 1527 | + def test_no_retry_on_timeout(self, tmp_path): |
| 1528 | + """Timeout errors are not retried (handled separately).""" |
| 1529 | + agent = _make_agent() |
| 1530 | + conversation = self._make_conversation_with_message(tmp_path) |
| 1531 | + |
| 1532 | + mock_client = _OpenHandsACPBridge() |
| 1533 | + agent._client = mock_client |
| 1534 | + agent._conn = MagicMock() |
| 1535 | + agent._session_id = "test-session" |
| 1536 | + |
| 1537 | + call_count = 0 |
| 1538 | + |
| 1539 | + def _fake_run_async(_coro, **_kwargs): |
| 1540 | + nonlocal call_count |
| 1541 | + call_count += 1 |
| 1542 | + raise TimeoutError("ACP prompt timed out") |
| 1543 | + |
| 1544 | + mock_executor = MagicMock() |
| 1545 | + mock_executor.run_async = _fake_run_async |
| 1546 | + agent._executor = mock_executor |
| 1547 | + |
| 1548 | + agent.step(conversation, on_event=lambda _: None) |
| 1549 | + |
| 1550 | + assert call_count == 1 # No retry for timeout |
| 1551 | + assert conversation.state.execution_status == ConversationExecutionStatus.ERROR |
| 1552 | + |
| 1553 | + def test_max_retries_exceeded(self, tmp_path): |
| 1554 | + """Error raised after max retries exhausted.""" |
| 1555 | + agent = _make_agent() |
| 1556 | + conversation = self._make_conversation_with_message(tmp_path) |
| 1557 | + events: list = [] |
| 1558 | + |
| 1559 | + mock_client = _OpenHandsACPBridge() |
| 1560 | + agent._client = mock_client |
| 1561 | + agent._conn = MagicMock() |
| 1562 | + agent._session_id = "test-session" |
| 1563 | + |
| 1564 | + call_count = 0 |
| 1565 | + |
| 1566 | + def _fake_run_async(_coro, **_kwargs): |
| 1567 | + nonlocal call_count |
| 1568 | + call_count += 1 |
| 1569 | + raise ConnectionError("Persistent connection failure") |
| 1570 | + |
| 1571 | + mock_executor = MagicMock() |
| 1572 | + mock_executor.run_async = _fake_run_async |
| 1573 | + agent._executor = mock_executor |
| 1574 | + |
| 1575 | + with patch("openhands.sdk.agent.acp_agent.time.sleep"): |
| 1576 | + with pytest.raises(ConnectionError, match="Persistent connection failure"): |
| 1577 | + agent.step(conversation, on_event=events.append) |
| 1578 | + |
| 1579 | + # Default max retries is 3, so 4 total attempts (1 initial + 3 retries) |
| 1580 | + assert call_count == 4 |
| 1581 | + assert conversation.state.execution_status == ConversationExecutionStatus.ERROR |
0 commit comments