|
| 1 | +from typing import Generator |
| 2 | +from unittest.mock import Mock, patch |
| 3 | + |
| 4 | +import msrest.exceptions |
| 5 | +import pytest |
| 6 | +from azure.batch.custom.custom_errors import CreateTasksErrorException |
| 7 | +from azure.batch.models import BatchError, BatchErrorDetail, ErrorMessage, TaskAddResult |
| 8 | + |
| 9 | +from pctasks.run.batch.client import BatchClient |
| 10 | +from pctasks.run.batch.task import BatchTask |
| 11 | +from pctasks.run.settings import BatchSettings |
| 12 | + |
| 13 | + |
| 14 | +class CustomExceptionWithMessage(Exception): |
| 15 | + """Custom exception with message attribute for testing.""" |
| 16 | + |
| 17 | + def __init__(self, message: str): |
| 18 | + self.message = message |
| 19 | + super().__init__(message) |
| 20 | + |
| 21 | + |
| 22 | +@pytest.fixture |
| 23 | +def mock_settings() -> Mock: |
| 24 | + mock_settings = Mock(spec=BatchSettings) |
| 25 | + mock_settings.url = "https://test.batch.azure.com" |
| 26 | + mock_settings.key = "test_key" |
| 27 | + mock_settings.default_pool_id = "test_pool" |
| 28 | + mock_settings.submit_threads = 4 |
| 29 | + mock_settings.cache_seconds = 5 |
| 30 | + mock_settings.get_batch_name = Mock(return_value="test") |
| 31 | + return mock_settings |
| 32 | + |
| 33 | + |
| 34 | +@pytest.fixture |
| 35 | +def batch_client(mock_settings: Mock) -> BatchClient: |
| 36 | + client = BatchClient(mock_settings) |
| 37 | + return client |
| 38 | + |
| 39 | + |
| 40 | +@pytest.fixture |
| 41 | +def mock_batch_service_client(batch_client: BatchClient) -> Generator[Mock, None, None]: |
| 42 | + mock_client = Mock() |
| 43 | + with patch.object(batch_client, "_ensure_client", return_value=mock_client): |
| 44 | + yield mock_client |
| 45 | + |
| 46 | + |
| 47 | +@pytest.fixture |
| 48 | +def mock_task(mock_batch_service_client: Mock) -> Mock: |
| 49 | + mock_task_api = Mock() |
| 50 | + mock_batch_service_client.task = mock_task_api |
| 51 | + return mock_task_api |
| 52 | + |
| 53 | + |
| 54 | +@pytest.fixture |
| 55 | +def mock_error() -> Mock: |
| 56 | + error = Mock(spec=BatchError) |
| 57 | + error.code = "test_error" |
| 58 | + error.message = None |
| 59 | + error.values = None |
| 60 | + return error |
| 61 | + |
| 62 | + |
| 63 | +@pytest.fixture |
| 64 | +def mock_task_add_result(mock_error: Mock) -> Mock: |
| 65 | + result = Mock(spec=TaskAddResult) |
| 66 | + result.error = mock_error |
| 67 | + result.task_id = "test-task-id" |
| 68 | + return result |
| 69 | + |
| 70 | + |
| 71 | +@pytest.fixture |
| 72 | +def create_tasks_exception(mock_task_add_result: Mock) -> CreateTasksErrorException: |
| 73 | + return CreateTasksErrorException( |
| 74 | + "ClientError", |
| 75 | + errors=[], |
| 76 | + failure_tasks=[mock_task_add_result], |
| 77 | + ) |
| 78 | + |
| 79 | + |
| 80 | +@pytest.fixture |
| 81 | +def mock_tasks() -> list[Mock]: |
| 82 | + tasks = [Mock(spec=BatchTask) for _ in range(3)] |
| 83 | + for task in tasks: |
| 84 | + task.to_params = Mock(return_value={}) |
| 85 | + return tasks |
| 86 | + |
| 87 | + |
| 88 | +def test_add_collection_handles_missing_message( |
| 89 | + batch_client: BatchClient, |
| 90 | + mock_task: Mock, |
| 91 | + create_tasks_exception: CreateTasksErrorException, |
| 92 | + mock_tasks: list[Mock], |
| 93 | +) -> None: |
| 94 | + mock_task.add_collection.side_effect = create_tasks_exception |
| 95 | + with pytest.raises(CreateTasksErrorException) as excinfo: |
| 96 | + batch_client.add_collection("test-job-id", mock_tasks) |
| 97 | + |
| 98 | + assert excinfo.value is create_tasks_exception |
| 99 | + |
| 100 | + |
| 101 | +@pytest.mark.parametrize( |
| 102 | + "error_type, error_args, expected_code, expected_message", |
| 103 | + [ |
| 104 | + ( |
| 105 | + BatchError, |
| 106 | + { |
| 107 | + "code": "BatchErrorCode", |
| 108 | + "message": ErrorMessage(value="Batch error message"), |
| 109 | + }, |
| 110 | + "BatchErrorCode", |
| 111 | + "Batch error message", |
| 112 | + ), |
| 113 | + ( |
| 114 | + msrest.exceptions.ClientRequestError, |
| 115 | + ["Client request failed"], |
| 116 | + "ClientRequestError", |
| 117 | + "Client request failed", |
| 118 | + ), |
| 119 | + (ValueError, ["Invalid value"], "ValueError", "Invalid value"), |
| 120 | + ( |
| 121 | + CustomExceptionWithMessage, |
| 122 | + ["Custom error message"], |
| 123 | + "CustomExceptionWithMessage", |
| 124 | + "Custom error message", |
| 125 | + ), |
| 126 | + (Exception, [], "Exception", ""), |
| 127 | + ( |
| 128 | + BatchError, |
| 129 | + { |
| 130 | + "code": "ValueError", |
| 131 | + "message": ErrorMessage(value="Error with details"), |
| 132 | + "values": [BatchErrorDetail(key="detail1", value="value1")], |
| 133 | + }, |
| 134 | + "ValueError", |
| 135 | + "Error with details", |
| 136 | + ), |
| 137 | + ], |
| 138 | +) |
| 139 | +def test_to_batch_error_conversion( |
| 140 | + batch_client: BatchClient, |
| 141 | + error_type: type, |
| 142 | + error_args: list, |
| 143 | + expected_code: str, |
| 144 | + expected_message: str, |
| 145 | +) -> None: |
| 146 | + if isinstance(error_args, dict): |
| 147 | + error = error_type(**error_args) |
| 148 | + else: |
| 149 | + error = error_type(*error_args) |
| 150 | + |
| 151 | + batch_error = batch_client._to_batch_error(error) |
| 152 | + |
| 153 | + assert batch_error.code == expected_code |
| 154 | + |
| 155 | + if hasattr(batch_error.message, "value"): |
| 156 | + assert batch_error.message.value == expected_message |
| 157 | + else: |
| 158 | + assert str(batch_error.message) == expected_message |
| 159 | + |
| 160 | + if hasattr(error, "values") and error.values: |
| 161 | + assert batch_error.values == error.values |
| 162 | + |
| 163 | + |
| 164 | +def test_add_collection_handles_different_error_types( |
| 165 | + batch_client: BatchClient, |
| 166 | + mock_task: Mock, |
| 167 | + mock_tasks: list[Mock], |
| 168 | +) -> None: |
| 169 | + client_request_error = msrest.exceptions.ClientRequestError("Network failure") |
| 170 | + |
| 171 | + batch_error = BatchError( |
| 172 | + code="BatchErrorCode", message=ErrorMessage(value="Batch processing error") |
| 173 | + ) |
| 174 | + |
| 175 | + task_result1 = Mock(spec=TaskAddResult) |
| 176 | + task_result1.task_id = "task-1" |
| 177 | + task_result1.error = client_request_error |
| 178 | + |
| 179 | + task_result2 = Mock(spec=TaskAddResult) |
| 180 | + task_result2.task_id = "task-2" |
| 181 | + task_result2.error = batch_error |
| 182 | + |
| 183 | + mixed_exception = CreateTasksErrorException( |
| 184 | + pending_tasks=[], |
| 185 | + failure_tasks=[task_result1, task_result2], |
| 186 | + errors=[ValueError("Another error")], |
| 187 | + ) |
| 188 | + |
| 189 | + mock_task.add_collection.side_effect = mixed_exception |
| 190 | + |
| 191 | + with pytest.raises(CreateTasksErrorException) as excinfo: |
| 192 | + batch_client.add_collection("test-job-id", mock_tasks) |
| 193 | + |
| 194 | + assert excinfo.value is mixed_exception |
0 commit comments