|
| 1 | +""" |
| 2 | +Tests for OetcHandler job polling and status monitoring. |
| 3 | +
|
| 4 | +This module tests the wait_and_get_job_data method which polls for job completion |
| 5 | +and handles various job states and error conditions. |
| 6 | +""" |
| 7 | + |
| 8 | +from datetime import datetime |
| 9 | +from unittest.mock import Mock, patch |
| 10 | + |
| 11 | +import pytest |
| 12 | +from requests import RequestException |
| 13 | + |
| 14 | +from linopy.remote.oetc import ( |
| 15 | + AuthenticationResult, |
| 16 | + ComputeProvider, |
| 17 | + OetcCredentials, |
| 18 | + OetcHandler, |
| 19 | + OetcSettings, |
| 20 | +) |
| 21 | + |
| 22 | + |
| 23 | +@pytest.fixture |
| 24 | +def mock_settings(): |
| 25 | + """Create mock settings for testing.""" |
| 26 | + credentials = OetcCredentials( email="[email protected]", password="test_password") |
| 27 | + return OetcSettings( |
| 28 | + credentials=credentials, |
| 29 | + name="Test Job", |
| 30 | + authentication_server_url="https://auth.example.com", |
| 31 | + orchestrator_server_url="https://orchestrator.example.com", |
| 32 | + compute_provider=ComputeProvider.GCP, |
| 33 | + ) |
| 34 | + |
| 35 | + |
| 36 | +@pytest.fixture |
| 37 | +def mock_auth_result(): |
| 38 | + """Create mock authentication result.""" |
| 39 | + return AuthenticationResult( |
| 40 | + token="mock_token", |
| 41 | + token_type="Bearer", |
| 42 | + expires_in=3600, |
| 43 | + authenticated_at=datetime.now(), |
| 44 | + ) |
| 45 | + |
| 46 | + |
| 47 | +@pytest.fixture |
| 48 | +def oetc_handler(mock_settings, mock_auth_result): |
| 49 | + """Create OetcHandler with mocked authentication.""" |
| 50 | + with patch( |
| 51 | + "linopy.remote.oetc.OetcHandler._OetcHandler__sign_in", |
| 52 | + return_value=mock_auth_result, |
| 53 | + ): |
| 54 | + with patch( |
| 55 | + "linopy.remote.oetc.OetcHandler._OetcHandler__get_cloud_provider_credentials" |
| 56 | + ): |
| 57 | + handler = OetcHandler(mock_settings) |
| 58 | + return handler |
| 59 | + |
| 60 | + |
| 61 | +class TestJobPollingSuccess: |
| 62 | + """Test successful job polling scenarios.""" |
| 63 | + |
| 64 | + def test_job_completes_immediately(self, oetc_handler): |
| 65 | + """Test job that completes on first poll.""" |
| 66 | + job_data = { |
| 67 | + "uuid": "job-123", |
| 68 | + "status": "FINISHED", |
| 69 | + "name": "test-job", |
| 70 | + "owner": "test-user", |
| 71 | + "solver": "highs", |
| 72 | + "duration_in_seconds": 120, |
| 73 | + "solving_duration_in_seconds": 90, |
| 74 | + "input_files": ["input.nc"], |
| 75 | + "output_files": ["output.nc"], |
| 76 | + "created_at": "2024-01-01T00:00:00Z", |
| 77 | + } |
| 78 | + |
| 79 | + with patch("requests.get") as mock_get: |
| 80 | + mock_response = Mock() |
| 81 | + mock_response.raise_for_status.return_value = None |
| 82 | + mock_response.json.return_value = job_data |
| 83 | + mock_get.return_value = mock_response |
| 84 | + |
| 85 | + result = oetc_handler.wait_and_get_job_data("job-123") |
| 86 | + |
| 87 | + assert result.uuid == "job-123" |
| 88 | + assert result.status == "FINISHED" |
| 89 | + assert result.output_files == ["output.nc"] |
| 90 | + mock_get.assert_called_once() |
| 91 | + |
| 92 | + def test_job_completes_with_no_output_files_warning(self, oetc_handler): |
| 93 | + """Test job completion with no output files generates warning.""" |
| 94 | + job_data = {"uuid": "job-123", "status": "FINISHED", "output_files": []} |
| 95 | + |
| 96 | + with patch("requests.get") as mock_get: |
| 97 | + with patch("linopy.remote.oetc.logger.warning") as mock_warning: |
| 98 | + mock_response = Mock() |
| 99 | + mock_response.raise_for_status.return_value = None |
| 100 | + mock_response.json.return_value = job_data |
| 101 | + mock_get.return_value = mock_response |
| 102 | + |
| 103 | + result = oetc_handler.wait_and_get_job_data("job-123") |
| 104 | + |
| 105 | + assert result.status == "FINISHED" |
| 106 | + mock_warning.assert_called_once_with( |
| 107 | + "OETC - Warning: Job completed but no output files found" |
| 108 | + ) |
| 109 | + |
| 110 | + @patch("time.sleep") # Mock sleep to speed up test |
| 111 | + def test_job_polling_progression(self, mock_sleep, oetc_handler): |
| 112 | + """Test job progresses through multiple states before completion.""" |
| 113 | + responses = [ |
| 114 | + {"uuid": "job-123", "status": "PENDING"}, |
| 115 | + {"uuid": "job-123", "status": "STARTING"}, |
| 116 | + {"uuid": "job-123", "status": "RUNNING", "duration_in_seconds": 30}, |
| 117 | + {"uuid": "job-123", "status": "RUNNING", "duration_in_seconds": 60}, |
| 118 | + {"uuid": "job-123", "status": "FINISHED", "output_files": ["output.nc"]}, |
| 119 | + ] |
| 120 | + |
| 121 | + with patch("requests.get") as mock_get: |
| 122 | + mock_response = Mock() |
| 123 | + mock_response.raise_for_status.return_value = None |
| 124 | + mock_response.json.side_effect = responses |
| 125 | + mock_get.return_value = mock_response |
| 126 | + |
| 127 | + result = oetc_handler.wait_and_get_job_data( |
| 128 | + "job-123", initial_poll_interval=1 |
| 129 | + ) |
| 130 | + |
| 131 | + assert result.status == "FINISHED" |
| 132 | + assert mock_get.call_count == 5 |
| 133 | + assert mock_sleep.call_count == 4 # Sleep called 4 times between 5 polls |
| 134 | + |
| 135 | + @patch("time.sleep") |
| 136 | + def test_polling_interval_backoff(self, mock_sleep, oetc_handler): |
| 137 | + """Test polling interval increases with exponential backoff.""" |
| 138 | + responses = [ |
| 139 | + {"uuid": "job-123", "status": "PENDING"}, |
| 140 | + {"uuid": "job-123", "status": "RUNNING"}, |
| 141 | + {"uuid": "job-123", "status": "FINISHED", "output_files": ["output.nc"]}, |
| 142 | + ] |
| 143 | + |
| 144 | + with patch("requests.get") as mock_get: |
| 145 | + mock_response = Mock() |
| 146 | + mock_response.raise_for_status.return_value = None |
| 147 | + mock_response.json.side_effect = responses |
| 148 | + mock_get.return_value = mock_response |
| 149 | + |
| 150 | + oetc_handler.wait_and_get_job_data( |
| 151 | + "job-123", initial_poll_interval=10, max_poll_interval=100 |
| 152 | + ) |
| 153 | + |
| 154 | + # Verify sleep was called with increasing intervals |
| 155 | + sleep_calls = [call[0][0] for call in mock_sleep.call_args_list] |
| 156 | + assert sleep_calls[0] == 10 # Initial interval |
| 157 | + assert sleep_calls[1] == 15 # 10 * 1.5 = 15 |
| 158 | + |
| 159 | + |
| 160 | +class TestJobPollingErrors: |
| 161 | + """Test job polling error scenarios.""" |
| 162 | + |
| 163 | + def test_setup_error_status(self, oetc_handler): |
| 164 | + """Test job with SETUP_ERROR status raises exception.""" |
| 165 | + job_data = {"uuid": "job-123", "status": "SETUP_ERROR"} |
| 166 | + |
| 167 | + with patch("requests.get") as mock_get: |
| 168 | + mock_response = Mock() |
| 169 | + mock_response.raise_for_status.return_value = None |
| 170 | + mock_response.json.return_value = job_data |
| 171 | + mock_get.return_value = mock_response |
| 172 | + |
| 173 | + with pytest.raises(Exception, match="Job failed during setup phase"): |
| 174 | + oetc_handler.wait_and_get_job_data("job-123") |
| 175 | + |
| 176 | + def test_runtime_error_status(self, oetc_handler): |
| 177 | + """Test job with RUNTIME_ERROR status raises exception.""" |
| 178 | + job_data = {"uuid": "job-123", "status": "RUNTIME_ERROR"} |
| 179 | + |
| 180 | + with patch("requests.get") as mock_get: |
| 181 | + mock_response = Mock() |
| 182 | + mock_response.raise_for_status.return_value = None |
| 183 | + mock_response.json.return_value = job_data |
| 184 | + mock_get.return_value = mock_response |
| 185 | + |
| 186 | + with pytest.raises(Exception, match="Job failed during execution"): |
| 187 | + oetc_handler.wait_and_get_job_data("job-123") |
| 188 | + |
| 189 | + def test_unknown_status_error(self, oetc_handler): |
| 190 | + """Test job with unknown status raises exception.""" |
| 191 | + job_data = {"uuid": "job-123", "status": "UNKNOWN_STATUS"} |
| 192 | + |
| 193 | + with patch("requests.get") as mock_get: |
| 194 | + mock_response = Mock() |
| 195 | + mock_response.raise_for_status.return_value = None |
| 196 | + mock_response.json.return_value = job_data |
| 197 | + mock_get.return_value = mock_response |
| 198 | + |
| 199 | + with pytest.raises(Exception, match="Unknown job status: UNKNOWN_STATUS"): |
| 200 | + oetc_handler.wait_and_get_job_data("job-123") |
| 201 | + |
| 202 | + |
| 203 | +class TestJobPollingNetworkErrors: |
| 204 | + """Test network error handling during job polling.""" |
| 205 | + |
| 206 | + @patch("time.sleep") |
| 207 | + def test_network_retry_success(self, mock_sleep, oetc_handler): |
| 208 | + """Test network errors are retried and eventually succeed.""" |
| 209 | + successful_response = { |
| 210 | + "uuid": "job-123", |
| 211 | + "status": "FINISHED", |
| 212 | + "output_files": ["output.nc"], |
| 213 | + } |
| 214 | + |
| 215 | + with patch("requests.get") as mock_get: |
| 216 | + # First two calls fail, third succeeds |
| 217 | + mock_get.side_effect = [ |
| 218 | + RequestException("Network error 1"), |
| 219 | + RequestException("Network error 2"), |
| 220 | + Mock( |
| 221 | + raise_for_status=Mock(), json=Mock(return_value=successful_response) |
| 222 | + ), |
| 223 | + ] |
| 224 | + |
| 225 | + result = oetc_handler.wait_and_get_job_data("job-123") |
| 226 | + |
| 227 | + assert result.status == "FINISHED" |
| 228 | + assert mock_get.call_count == 3 |
| 229 | + assert mock_sleep.call_count == 2 # Retry delays |
| 230 | + |
| 231 | + # Verify retry delays increase |
| 232 | + sleep_calls = [call[0][0] for call in mock_sleep.call_args_list] |
| 233 | + assert sleep_calls[0] == 10 # First retry: 1 * 10 = 10 |
| 234 | + assert sleep_calls[1] == 20 # Second retry: 2 * 10 = 20 |
| 235 | + |
| 236 | + @patch("time.sleep") |
| 237 | + def test_max_network_retries_exceeded(self, mock_sleep, oetc_handler): |
| 238 | + """Test max network retries causes exception.""" |
| 239 | + with patch("requests.get") as mock_get: |
| 240 | + # All calls fail with RequestException |
| 241 | + mock_get.side_effect = RequestException("Network error") |
| 242 | + |
| 243 | + with pytest.raises( |
| 244 | + Exception, match="Failed to get job status after 10 network retries" |
| 245 | + ): |
| 246 | + oetc_handler.wait_and_get_job_data("job-123") |
| 247 | + |
| 248 | + # Should retry exactly 10 times before failing |
| 249 | + assert mock_get.call_count == 10 |
| 250 | + |
| 251 | + @patch("time.sleep") |
| 252 | + def test_network_retry_delay_cap(self, mock_sleep, oetc_handler): |
| 253 | + """Test network retry delay is capped at 60 seconds.""" |
| 254 | + with patch("requests.get") as mock_get: |
| 255 | + mock_get.side_effect = RequestException("Network error") |
| 256 | + |
| 257 | + with pytest.raises(Exception): |
| 258 | + oetc_handler.wait_and_get_job_data("job-123") |
| 259 | + |
| 260 | + # Check that delay is capped at 60 seconds |
| 261 | + sleep_calls = [call[0][0] for call in mock_sleep.call_args_list] |
| 262 | + assert all(delay <= 60 for delay in sleep_calls) |
| 263 | + |
| 264 | + def test_keyerror_in_response(self, oetc_handler): |
| 265 | + """Test KeyError in response parsing raises exception.""" |
| 266 | + with patch("requests.get") as mock_get: |
| 267 | + mock_response = Mock() |
| 268 | + mock_response.raise_for_status.return_value = None |
| 269 | + mock_response.json.return_value = {} # Missing required 'uuid' field |
| 270 | + mock_get.return_value = mock_response |
| 271 | + |
| 272 | + with pytest.raises( |
| 273 | + Exception, match="Invalid job status response format: missing field" |
| 274 | + ): |
| 275 | + oetc_handler.wait_and_get_job_data("job-123") |
| 276 | + |
| 277 | + def test_generic_exception_handling(self, oetc_handler): |
| 278 | + """Test generic exception handling in polling loop.""" |
| 279 | + with patch("requests.get") as mock_get: |
| 280 | + mock_response = Mock() |
| 281 | + mock_response.raise_for_status.side_effect = ValueError("Unexpected error") |
| 282 | + mock_get.return_value = mock_response |
| 283 | + |
| 284 | + with pytest.raises( |
| 285 | + Exception, match="Error getting job status: Unexpected error" |
| 286 | + ): |
| 287 | + oetc_handler.wait_and_get_job_data("job-123") |
| 288 | + |
| 289 | + def test_status_error_exception_preserved(self, oetc_handler): |
| 290 | + """Test that status-related exceptions are preserved.""" |
| 291 | + with patch("requests.get") as mock_get: |
| 292 | + # Simulate an exception that mentions "status:" - should be re-raised as-is |
| 293 | + mock_response = Mock() |
| 294 | + mock_response.raise_for_status.side_effect = Exception( |
| 295 | + "Custom status: error" |
| 296 | + ) |
| 297 | + mock_get.return_value = mock_response |
| 298 | + |
| 299 | + with pytest.raises(Exception, match="Custom status: error"): |
| 300 | + oetc_handler.wait_and_get_job_data("job-123") |
| 301 | + |
| 302 | + def test_oetc_logs_exception_preserved(self, oetc_handler): |
| 303 | + """Test that OETC logs exceptions are preserved.""" |
| 304 | + with patch("requests.get") as mock_get: |
| 305 | + # Simulate an exception that mentions "OETC logs" - should be re-raised as-is |
| 306 | + mock_response = Mock() |
| 307 | + mock_response.raise_for_status.side_effect = Exception( |
| 308 | + "Check the OETC logs for details" |
| 309 | + ) |
| 310 | + mock_get.return_value = mock_response |
| 311 | + |
| 312 | + with pytest.raises(Exception, match="Check the OETC logs for details"): |
| 313 | + oetc_handler.wait_and_get_job_data("job-123") |
0 commit comments