Skip to content

Commit a640f4c

Browse files
committed
test fixes
1 parent f272e6f commit a640f4c

File tree

3 files changed

+51
-36
lines changed

3 files changed

+51
-36
lines changed

packages/models-library/src/models_library/progress_bar.py

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Literal, TypeAlias
22

33
from pydantic import BaseModel, ConfigDict
4+
from pydantic.config import JsonDict
45

56
# NOTE: keep a list of possible unit, and please use correct official unit names
67
ProgressUnit: TypeAlias = Literal["Byte"]
@@ -13,34 +14,38 @@ class ProgressStructuredMessage(BaseModel):
1314
unit: str | None = None
1415
sub: "ProgressStructuredMessage | None" = None
1516

16-
model_config = ConfigDict(
17-
json_schema_extra={
18-
"examples": [
19-
{
20-
"description": "some description",
21-
"current": 12.2,
22-
"total": 123,
23-
},
24-
{
25-
"description": "some description",
26-
"current": 12.2,
27-
"total": 123,
28-
"unit": "Byte",
29-
},
30-
{
31-
"description": "downloading",
32-
"current": 2.0,
33-
"total": 5,
34-
"sub": {
35-
"description": "port 2",
17+
@staticmethod
18+
def _update_json_schema_extra(schema: JsonDict) -> None:
19+
schema.update(
20+
{
21+
"examples": [
22+
{
23+
"description": "some description",
24+
"current": 12.2,
25+
"total": 123,
26+
},
27+
{
28+
"description": "some description",
3629
"current": 12.2,
3730
"total": 123,
3831
"unit": "Byte",
3932
},
40-
},
41-
]
42-
}
43-
)
33+
{
34+
"description": "downloading",
35+
"current": 2.0,
36+
"total": 5,
37+
"sub": {
38+
"description": "port 2",
39+
"current": 12.2,
40+
"total": 123,
41+
"unit": "Byte",
42+
},
43+
},
44+
]
45+
}
46+
)
47+
48+
model_config = ConfigDict(json_schema_extra=_update_json_schema_extra)
4449

4550

4651
UNITLESS = None
@@ -96,7 +101,17 @@ def composed_message(self) -> str:
96101
{
97102
"actual_value": 0.3,
98103
"total": 1.0,
99-
"message": ProgressStructuredMessage.model_config["json_schema_extra"]["examples"][2], # type: ignore [index]
104+
"message": {
105+
"description": "downloading",
106+
"current": 2.0,
107+
"total": 5,
108+
"sub": {
109+
"description": "port 2",
110+
"current": 12.2,
111+
"total": 123,
112+
"unit": "Byte",
113+
},
114+
},
100115
},
101116
]
102117
},

services/api-server/tests/unit/api_functions/celery/test_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ async def test_celery_error_propagation(
304304
with pytest.raises(HTTPStatusError) as exc_info:
305305
await poll_task_until_done(client, auth, f"{task_uuid}")
306306

307-
assert exc_info.value.response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
307+
assert exc_info.value.response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
308308

309309

310310
@pytest.mark.parametrize(

services/api-server/tests/unit/test_tasks.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ async def test_get_task_result(
102102
None,
103103
None,
104104
None,
105-
status.HTTP_500_INTERNAL_SERVER_ERROR,
105+
status.HTTP_503_SERVICE_UNAVAILABLE,
106106
),
107107
(
108108
"GET",
@@ -111,7 +111,7 @@ async def test_get_task_result(
111111
CeleryError(),
112112
None,
113113
None,
114-
status.HTTP_500_INTERNAL_SERVER_ERROR,
114+
status.HTTP_503_SERVICE_UNAVAILABLE,
115115
),
116116
(
117117
"POST",
@@ -120,7 +120,7 @@ async def test_get_task_result(
120120
None,
121121
CeleryError(),
122122
None,
123-
status.HTTP_500_INTERNAL_SERVER_ERROR,
123+
status.HTTP_503_SERVICE_UNAVAILABLE,
124124
),
125125
(
126126
"GET",
@@ -129,7 +129,7 @@ async def test_get_task_result(
129129
CeleryError(),
130130
None,
131131
None,
132-
status.HTTP_500_INTERNAL_SERVER_ERROR,
132+
status.HTTP_503_SERVICE_UNAVAILABLE,
133133
),
134134
(
135135
"GET",
@@ -142,9 +142,9 @@ async def test_get_task_result(
142142
actual_value=0.5,
143143
total=1.0,
144144
unit="Byte",
145-
message=ProgressStructuredMessage.model_config["json_schema_extra"][
146-
"examples"
147-
][0],
145+
message=ProgressStructuredMessage.model_json_schema()["examples"][
146+
0
147+
],
148148
),
149149
),
150150
None,
@@ -162,9 +162,9 @@ async def test_get_task_result(
162162
actual_value=0.5,
163163
total=1.0,
164164
unit="Byte",
165-
message=ProgressStructuredMessage.model_config["json_schema_extra"][
166-
"examples"
167-
][0],
165+
message=ProgressStructuredMessage.model_json_schema()["examples"][
166+
0
167+
],
168168
),
169169
),
170170
None,

0 commit comments

Comments
 (0)