Skip to content

Commit c3e7129

Browse files
author
Andrei Neagu
committed
fixed broken
1 parent 888e10f commit c3e7129

File tree

2 files changed

+69
-15
lines changed

2 files changed

+69
-15
lines changed

packages/models-library/src/models_library/api_schemas_long_running_tasks/tasks.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from datetime import datetime
33
from typing import Any
44

5-
from pydantic import BaseModel, field_validator
5+
from pydantic import BaseModel, model_validator
66

77
from .base import TaskId, TaskProgress
88

@@ -20,24 +20,22 @@ class TaskResult(BaseModel):
2020

2121
class TaskBase(BaseModel):
2222
task_id: TaskId
23-
24-
# NOTE: task name can always be extraced from the task_id
25-
# since it'e encoded inside it (expect when this is ued
26-
# with data coming form the celery tasks)
2723
task_name: str = ""
2824

29-
@field_validator("task_name", mode="before")
30-
@classmethod
31-
def populate_task_name_if_not_provided(cls, task_name: str, info):
32-
# attempt to extract the task name from the task_id
33-
# if this is coming form a long_running_task
34-
task_id = info.data["task_id"]
35-
if task_id and task_name == "":
36-
parts = task_id.split(".")
25+
@model_validator(mode="after")
26+
def try_populate_task_name_from_task_id(self) -> "TaskBase":
27+
# NOTE: currently this model is used to validate tasks coming from
28+
# the celery backend and form long_running_tasks
29+
# 1. if a task comes from Celery, it will keep it's given name
30+
# 2. if a task comes from long_running_tasks, it will extract it form
31+
# the task_id, which looks like "{PREFIX}.{TASK_NAME}.UNIQUE|{UUID}"
32+
33+
if self.task_id and self.task_name == "":
34+
parts = self.task_id.split(".")
3735
if len(parts) > 1:
38-
task_name = urllib.parse.unquote(parts[1])
36+
self.task_name = urllib.parse.unquote(parts[1])
3937

40-
return task_name
38+
return self
4139

4240

4341
class TaskGet(TaskBase):
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import pytest
2+
from models_library.api_schemas_long_running_tasks.tasks import TaskGet
3+
from pydantic import TypeAdapter
4+
5+
6+
def _get_data_withut_task_name(task_id: str) -> dict:
7+
return {
8+
"task_id": task_id,
9+
"status_href": "",
10+
"result_href": "",
11+
"abort_href": "",
12+
}
13+
14+
15+
@pytest.mark.parametrize(
16+
"data, expected_task_name",
17+
[
18+
(_get_data_withut_task_name("a.b.c.d"), "b"),
19+
(_get_data_withut_task_name("a.b.c"), "b"),
20+
(_get_data_withut_task_name("a.b"), "b"),
21+
(_get_data_withut_task_name("a"), ""),
22+
],
23+
)
24+
def test_try_extract_task_name(data: dict, expected_task_name: str) -> None:
25+
task_get = TaskGet(**data)
26+
assert task_get.task_name == expected_task_name
27+
28+
task_get = TypeAdapter(TaskGet).validate_python(data)
29+
assert task_get.task_name == expected_task_name
30+
31+
32+
def _get_data_with_task_name(task_id: str, task_name: str) -> dict:
33+
return {
34+
"task_id": task_id,
35+
"task_name": task_name,
36+
"status_href": "",
37+
"result_href": "",
38+
"abort_href": "",
39+
}
40+
41+
42+
@pytest.mark.parametrize(
43+
"data, expected_task_name",
44+
[
45+
(_get_data_with_task_name("a.b.c.d", "a_name"), "a_name"),
46+
(_get_data_with_task_name("a.b.c", "a_name"), "a_name"),
47+
(_get_data_with_task_name("a.b", "a_name"), "a_name"),
48+
(_get_data_with_task_name("a", "a_name"), "a_name"),
49+
],
50+
)
51+
def test_task_name_is_provided(data: dict, expected_task_name: str) -> None:
52+
task_get = TaskGet(**data)
53+
assert task_get.task_name == expected_task_name
54+
55+
task_get = TypeAdapter(TaskGet).validate_python(data)
56+
assert task_get.task_name == expected_task_name

0 commit comments

Comments
 (0)