Skip to content

Commit 230fd0f

Browse files
authored
fix - batch message.error program crash when message is None (#334)
* fix - batch message.error program crash when message is None * add serializer func for converting any Exception to a BatchError format properly; update unit tests * isort
1 parent 5e16216 commit 230fd0f

File tree

2 files changed

+258
-23
lines changed

2 files changed

+258
-23
lines changed

pctasks/run/pctasks/run/batch/client.py

Lines changed: 64 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import urllib3.exceptions
1111
from azure.batch import BatchServiceClient
1212
from azure.batch.custom.custom_errors import CreateTasksErrorException
13+
from azure.batch.models import BatchError, BatchErrorDetail, ErrorMessage
1314
from dateutil.tz import tzutc
1415
from requests import Response
1516

@@ -123,6 +124,45 @@ def _mo(j: batchmodels.CloudJob) -> str:
123124

124125
return map_opt(_mo, result)
125126

127+
def _to_batch_error(self, error: Exception) -> BatchError:
128+
"""
129+
Convert any exception to BatchError format.
130+
131+
:param error: Any exception that needs to be converted to BatchError
132+
:type error: Exception
133+
:return: A BatchError representation of the exception
134+
:rtype: ~azure.batch.models.BatchError
135+
:raises AttributeError: If the ErrorMessage object cannot be properly created
136+
137+
Example:
138+
139+
```python
140+
try:
141+
# Some batch operation
142+
except Exception as e:
143+
batch_error = _to_batch_error(e)
144+
logger.error(f"Operation failed: {batch_error.message}")
145+
```
146+
"""
147+
148+
code: str = getattr(error, "code", type(error).__name__)
149+
150+
if hasattr(error, "message"):
151+
message = error.message
152+
else:
153+
message = str(error)
154+
155+
if isinstance(message, ErrorMessage):
156+
error_message = message
157+
else:
158+
error_message = ErrorMessage(value=message)
159+
160+
values: List[BatchErrorDetail] = []
161+
if hasattr(error, "values"):
162+
values = cast(List[BatchErrorDetail], error.values)
163+
164+
return BatchError(code=code, message=error_message, values=values)
165+
126166
def get_job(self, job_id: str) -> Optional[batchmodels.CloudJob]:
127167
client = self._ensure_client()
128168
try:
@@ -210,33 +250,35 @@ def add_collection(
210250
params = [task.to_params() for task in tasks]
211251
try:
212252
result: batchmodels.TaskAddCollectionResult = self._with_backoff(
213-
lambda: cast(
214-
batchmodels.TaskAddCollectionResult,
215-
client.task.add_collection(
216-
job_id=job_id,
217-
value=params,
218-
threads=self.settings.submit_threads,
219-
),
253+
lambda: client.task.add_collection(
254+
job_id=job_id,
255+
value=params,
256+
threads=self.settings.submit_threads,
220257
)
221258
)
222259
task_results: List[batchmodels.TaskAddResult] = result.value # type: ignore
223260
except CreateTasksErrorException as e:
224-
logger.warn("Failed to add tasks...")
261+
logger.error("Failed to add tasks...")
262+
225263
for exc in e.errors:
226-
logger.warn(" -- RETURNED EXCEPTION --")
227-
logger.exception(exc)
228-
for failure_task in e.failure_tasks:
229-
task_add_result = cast(batchmodels.TaskAddResult, failure_task)
230-
error = cast(batchmodels.BatchError, task_add_result.error)
231-
if error:
232-
logger.error(
233-
f"Task {task_add_result.task_id} failed with error: "
234-
f"{error.message}"
235-
)
236-
error_details = cast(batchmodels.BatchError, error).values
237-
if error_details:
238-
for detail in error_details:
239-
logger.error(f" - {detail.key}: {detail.value}")
264+
exc = cast(Exception, exc)
265+
logger.error(" -- RETURNED EXCEPTION --")
266+
logger.error(exc)
267+
268+
for task_add_result in e.failure_tasks:
269+
task_add_result = cast(batchmodels.TaskAddResult, task_add_result)
270+
271+
if task_add_result.error is None:
272+
continue
273+
274+
batch_error = self._to_batch_error(task_add_result.error)
275+
logger.error(
276+
f"Failed to create task {task_add_result.task_id} with error: {batch_error.message.value}" # noqa: E501
277+
)
278+
279+
if batch_error.values:
280+
for detail in batch_error.values:
281+
logger.error(f" - {detail.key}: {detail.value}")
240282
raise
241283
return [r.error for r in task_results]
242284

@@ -376,7 +418,6 @@ def restart_silent_tasks(
376418
print(f"{task.id} - {last_modified}")
377419

378420
except batchmodels.BatchErrorException:
379-
380421
# stdout.txt doesn't exist
381422
# Check if it's been running without output for
382423
# the max time.
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
from typing import Generator
2+
from unittest.mock import Mock, patch
3+
4+
import msrest.exceptions
5+
import pytest
6+
from azure.batch.custom.custom_errors import CreateTasksErrorException
7+
from azure.batch.models import BatchError, BatchErrorDetail, ErrorMessage, TaskAddResult
8+
9+
from pctasks.run.batch.client import BatchClient
10+
from pctasks.run.batch.task import BatchTask
11+
from pctasks.run.settings import BatchSettings
12+
13+
14+
class CustomExceptionWithMessage(Exception):
15+
"""Custom exception with message attribute for testing."""
16+
17+
def __init__(self, message: str):
18+
self.message = message
19+
super().__init__(message)
20+
21+
22+
@pytest.fixture
23+
def mock_settings() -> Mock:
24+
mock_settings = Mock(spec=BatchSettings)
25+
mock_settings.url = "https://test.batch.azure.com"
26+
mock_settings.key = "test_key"
27+
mock_settings.default_pool_id = "test_pool"
28+
mock_settings.submit_threads = 4
29+
mock_settings.cache_seconds = 5
30+
mock_settings.get_batch_name = Mock(return_value="test")
31+
return mock_settings
32+
33+
34+
@pytest.fixture
35+
def batch_client(mock_settings: Mock) -> BatchClient:
36+
client = BatchClient(mock_settings)
37+
return client
38+
39+
40+
@pytest.fixture
41+
def mock_batch_service_client(batch_client: BatchClient) -> Generator[Mock, None, None]:
42+
mock_client = Mock()
43+
with patch.object(batch_client, "_ensure_client", return_value=mock_client):
44+
yield mock_client
45+
46+
47+
@pytest.fixture
48+
def mock_task(mock_batch_service_client: Mock) -> Mock:
49+
mock_task_api = Mock()
50+
mock_batch_service_client.task = mock_task_api
51+
return mock_task_api
52+
53+
54+
@pytest.fixture
55+
def mock_error() -> Mock:
56+
error = Mock(spec=BatchError)
57+
error.code = "test_error"
58+
error.message = None
59+
error.values = None
60+
return error
61+
62+
63+
@pytest.fixture
64+
def mock_task_add_result(mock_error: Mock) -> Mock:
65+
result = Mock(spec=TaskAddResult)
66+
result.error = mock_error
67+
result.task_id = "test-task-id"
68+
return result
69+
70+
71+
@pytest.fixture
72+
def create_tasks_exception(mock_task_add_result: Mock) -> CreateTasksErrorException:
73+
return CreateTasksErrorException(
74+
"ClientError",
75+
errors=[],
76+
failure_tasks=[mock_task_add_result],
77+
)
78+
79+
80+
@pytest.fixture
81+
def mock_tasks() -> list[Mock]:
82+
tasks = [Mock(spec=BatchTask) for _ in range(3)]
83+
for task in tasks:
84+
task.to_params = Mock(return_value={})
85+
return tasks
86+
87+
88+
def test_add_collection_handles_missing_message(
89+
batch_client: BatchClient,
90+
mock_task: Mock,
91+
create_tasks_exception: CreateTasksErrorException,
92+
mock_tasks: list[Mock],
93+
) -> None:
94+
mock_task.add_collection.side_effect = create_tasks_exception
95+
with pytest.raises(CreateTasksErrorException) as excinfo:
96+
batch_client.add_collection("test-job-id", mock_tasks)
97+
98+
assert excinfo.value is create_tasks_exception
99+
100+
101+
@pytest.mark.parametrize(
102+
"error_type, error_args, expected_code, expected_message",
103+
[
104+
(
105+
BatchError,
106+
{
107+
"code": "BatchErrorCode",
108+
"message": ErrorMessage(value="Batch error message"),
109+
},
110+
"BatchErrorCode",
111+
"Batch error message",
112+
),
113+
(
114+
msrest.exceptions.ClientRequestError,
115+
["Client request failed"],
116+
"ClientRequestError",
117+
"Client request failed",
118+
),
119+
(ValueError, ["Invalid value"], "ValueError", "Invalid value"),
120+
(
121+
CustomExceptionWithMessage,
122+
["Custom error message"],
123+
"CustomExceptionWithMessage",
124+
"Custom error message",
125+
),
126+
(Exception, [], "Exception", ""),
127+
(
128+
BatchError,
129+
{
130+
"code": "ValueError",
131+
"message": ErrorMessage(value="Error with details"),
132+
"values": [BatchErrorDetail(key="detail1", value="value1")],
133+
},
134+
"ValueError",
135+
"Error with details",
136+
),
137+
],
138+
)
139+
def test_to_batch_error_conversion(
140+
batch_client: BatchClient,
141+
error_type: type,
142+
error_args: list,
143+
expected_code: str,
144+
expected_message: str,
145+
) -> None:
146+
if isinstance(error_args, dict):
147+
error = error_type(**error_args)
148+
else:
149+
error = error_type(*error_args)
150+
151+
batch_error = batch_client._to_batch_error(error)
152+
153+
assert batch_error.code == expected_code
154+
155+
if hasattr(batch_error.message, "value"):
156+
assert batch_error.message.value == expected_message
157+
else:
158+
assert str(batch_error.message) == expected_message
159+
160+
if hasattr(error, "values") and error.values:
161+
assert batch_error.values == error.values
162+
163+
164+
def test_add_collection_handles_different_error_types(
165+
batch_client: BatchClient,
166+
mock_task: Mock,
167+
mock_tasks: list[Mock],
168+
) -> None:
169+
client_request_error = msrest.exceptions.ClientRequestError("Network failure")
170+
171+
batch_error = BatchError(
172+
code="BatchErrorCode", message=ErrorMessage(value="Batch processing error")
173+
)
174+
175+
task_result1 = Mock(spec=TaskAddResult)
176+
task_result1.task_id = "task-1"
177+
task_result1.error = client_request_error
178+
179+
task_result2 = Mock(spec=TaskAddResult)
180+
task_result2.task_id = "task-2"
181+
task_result2.error = batch_error
182+
183+
mixed_exception = CreateTasksErrorException(
184+
pending_tasks=[],
185+
failure_tasks=[task_result1, task_result2],
186+
errors=[ValueError("Another error")],
187+
)
188+
189+
mock_task.add_collection.side_effect = mixed_exception
190+
191+
with pytest.raises(CreateTasksErrorException) as excinfo:
192+
batch_client.add_collection("test-job-id", mock_tasks)
193+
194+
assert excinfo.value is mixed_exception

0 commit comments

Comments
 (0)