Skip to content

Commit 2ac6d38

Browse files
authored
♻️Maintenance: Add UNKOWN type to DB State Type (🗃️) (#8284)
1 parent 8f372f6 commit 2ac6d38

File tree

6 files changed

+75
-22
lines changed

6 files changed

+75
-22
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""add state type unknown
2+
3+
Revision ID: 06eafd25d004
4+
Revises: ec4f62595e0c
5+
Create Date: 2025-09-01 12:25:25.617790+00:00
6+
7+
"""
8+
9+
import sqlalchemy as sa
10+
from alembic import op
11+
12+
# revision identifiers, used by Alembic.
13+
revision = "06eafd25d004"
14+
down_revision = "ec4f62595e0c"
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade():
20+
op.execute("ALTER TYPE statetype ADD VALUE 'UNKNOWN'")
21+
22+
23+
def downgrade() -> None:
24+
# NOTE: PostgreSQL doesn't support removing enum values directly
25+
# This downgrades only ensure that StateType.UNKNOWN is not used
26+
#
27+
28+
# Find all tables and columns that use statetype enum
29+
result = op.get_bind().execute(
30+
sa.DDL(
31+
"""
32+
SELECT t.table_name, c.column_name, c.column_default
33+
FROM information_schema.columns c
34+
JOIN information_schema.tables t ON c.table_name = t.table_name
35+
WHERE c.udt_name = 'statetype'
36+
AND t.table_schema = 'public'
37+
"""
38+
)
39+
)
40+
41+
tables_columns = result.fetchall()
42+
43+
# Update UNKNOWN states to FAILED in all affected tables
44+
for table_name, column_name, _ in tables_columns:
45+
op.execute(
46+
sa.DDL(
47+
f"""
48+
UPDATE {table_name}
49+
SET {column_name} = 'FAILED'
50+
WHERE {column_name} = 'UNKNOWN'
51+
"""
52+
)
53+
)

packages/postgres-database/src/simcore_postgres_database/models/comp_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
""" Computational Pipeline Table
1+
"""Computational Pipeline Table"""
22

3-
"""
43
import enum
54
import uuid
65

@@ -24,6 +23,7 @@ class StateType(enum.Enum):
2423
ABORTED = "ABORTED"
2524
WAITING_FOR_RESOURCES = "WAITING_FOR_RESOURCES"
2625
WAITING_FOR_CLUSTER = "WAITING_FOR_CLUSTER"
26+
UNKNOWN = "UNKNOWN"
2727

2828

2929
def _new_uuid():

services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,12 @@ def _handle_foreign_key_violation(
9595

9696
def _resolve_grouped_state(states: list[RunningState]) -> RunningState:
9797
# If any state is not a final state, return STARTED
98+
9899
final_states = {
99100
RunningState.FAILED,
100101
RunningState.ABORTED,
101102
RunningState.SUCCESS,
102-
RunningState.UNKNOWN,
103+
RunningState.UNKNOWN, # NOTE: this is NOT a final state, but happens when tasks are missing
103104
}
104105
if any(state not in final_states for state in states):
105106
return RunningState.STARTED
@@ -399,7 +400,6 @@ async def list_all_collection_run_ids_for_user_currently_running_computations(
399400
product_name: str,
400401
user_id: UserID,
401402
) -> list[CollectionRunID]:
402-
403403
list_query = (
404404
sa.select(
405405
comp_runs.c.collection_run_id,
@@ -493,17 +493,17 @@ async def list_group_by_collection_run_id(
493493
total_count = await conn.scalar(count_query)
494494
items = []
495495
async for row in await conn.stream(list_query):
496-
db_states = [DB_TO_RUNNING_STATE[s] for s in row["states"]]
496+
db_states = [DB_TO_RUNNING_STATE[s] for s in row.states]
497497
resolved_state = _resolve_grouped_state(db_states)
498498
items.append(
499499
ComputationCollectionRunRpcGet(
500-
collection_run_id=row["collection_run_id"],
501-
project_ids=row["project_ids"],
500+
collection_run_id=row.collection_run_id,
501+
project_ids=row.project_ids,
502502
state=resolved_state,
503-
info={} if row["info"] is None else row["info"],
504-
submitted_at=row["submitted_at"],
505-
started_at=row["started_at"],
506-
ended_at=row["ended_at"],
503+
info={} if row.info is None else row.info,
504+
submitted_at=row.submitted_at,
505+
started_at=row.started_at,
506+
ended_at=row.ended_at,
507507
)
508508
)
509509
return cast(int, total_count), items

services/director-v2/src/simcore_service_director_v2/utils/db.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,10 @@
1616
StateType.ABORTED: RunningState.ABORTED,
1717
StateType.WAITING_FOR_RESOURCES: RunningState.WAITING_FOR_RESOURCES,
1818
StateType.WAITING_FOR_CLUSTER: RunningState.WAITING_FOR_CLUSTER,
19+
StateType.UNKNOWN: RunningState.UNKNOWN,
1920
}
2021

21-
RUNNING_STATE_TO_DB = {v: k for k, v in DB_TO_RUNNING_STATE.items()} | {
22-
RunningState.UNKNOWN: StateType.FAILED
23-
}
22+
RUNNING_STATE_TO_DB = {v: k for k, v in DB_TO_RUNNING_STATE.items()}
2423

2524
_logger = logging.getLogger(__name__)
2625

services/director-v2/tests/unit/test_utils_computation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,9 @@ def test_get_pipeline_state_from_task_states(
296296
]
297297

298298
pipeline_state: RunningState = get_pipeline_state_from_task_states(tasks)
299-
assert (
300-
pipeline_state == exp_pipeline_state
301-
), f"task states are: {task_states}, got {pipeline_state} instead of {exp_pipeline_state}"
299+
assert pipeline_state == exp_pipeline_state, (
300+
f"task states are: {task_states}, got {pipeline_state} instead of {exp_pipeline_state}"
301+
)
302302

303303

304304
@pytest.mark.parametrize(
@@ -315,7 +315,7 @@ def test_get_pipeline_state_from_task_states(
315315
],
316316
)
317317
def test_is_pipeline_running(state, exp: bool):
318-
assert (
319-
is_pipeline_running(state) is exp
320-
), f"pipeline in {state}, i.e. running state should be {exp}"
318+
assert is_pipeline_running(state) is exp, (
319+
f"pipeline in {state}, i.e. running state should be {exp}"
320+
)
321321
assert is_pipeline_stopped is not exp

services/director-v2/tests/unit/with_dbs/comp_scheduler/test_db_repositories_comp_runs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -956,8 +956,9 @@ async def test_list_group_by_collection_run_id_with_unknown_returns_unknown(
956956
)
957957

958958
# Test the function
959+
assert "product_name" in run_metadata
959960
total_count, items = await repo.list_group_by_collection_run_id(
960-
product_name=run_metadata.get("product_name"),
961+
product_name=run_metadata["product_name"],
961962
user_id=published_project_1.user["id"],
962963
offset=0,
963964
limit=10,
@@ -967,7 +968,7 @@ async def test_list_group_by_collection_run_id_with_unknown_returns_unknown(
967968
assert total_count == 1
968969
assert len(items) == 1
969970
collection_item = items[0]
970-
assert collection_item.state == RunningState.FAILED
971+
assert collection_item.state == RunningState.UNKNOWN
971972

972973

973974
async def test_list_group_by_collection_run_id_with_project_filter(

0 commit comments

Comments
 (0)