Commit 07fa982

migrated to asyncpg
1 parent 405b7d7 commit 07fa982
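In short, the commit replaces the aiopg-based engine in node_ports_common with SQLAlchemy's asyncio engine backed by asyncpg. Below is a minimal, self-contained sketch of the connection pattern the file migrates to; the DSN string and the throwaway script shape are illustrative only, while the real module builds its engine from NodePortsSettings via create_async_engine_and_pg_database_ready, as the diff further down shows.

    import asyncio

    import sqlalchemy as sa
    from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine


    async def main() -> None:
        # aiopg style (removed):    async with engine.acquire() as conn: ...
        # asyncpg/SQLAlchemy async: AsyncEngine hands out async connections/transactions
        engine: AsyncEngine = create_async_engine(
            "postgresql+asyncpg://user:secret@localhost:5432/db"  # placeholder DSN
        )
        try:
            async with engine.connect() as conn:  # read-only use, nothing to commit
                print(await conn.scalar(sa.text("SELECT 1")))
            async with engine.begin() as conn:  # transaction, committed on clean exit
                await conn.execute(sa.text("SELECT 1"))
        finally:
            await engine.dispose()  # replaces aiopg's close_engine()


    asyncio.run(main())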

File tree (1 file changed: +40 -72 lines)

  • packages/simcore-sdk/src/simcore_sdk/node_ports_common

Lines changed: 40 additions & 72 deletions
@@ -1,107 +1,77 @@
 import json
 import logging
-import os
-import socket
-from typing import Any
 
-import aiopg.sa
 import sqlalchemy as sa
-import tenacity
-from aiopg.sa.engine import Engine
-from aiopg.sa.result import RowProxy
 from models_library.projects import ProjectID
 from models_library.users import UserID
-from servicelib.common_aiopg_utils import DataSourceName, create_pg_engine
-from servicelib.retry_policies import PostgresRetryPolicyUponInitialization
+from pydantic import TypeAdapter
+from servicelib.db_asyncpg_utils import create_async_engine_and_pg_database_ready
 from settings_library.node_ports import NodePortsSettings
 from simcore_postgres_database.models.comp_tasks import comp_tasks
 from simcore_postgres_database.models.projects import projects
-from simcore_postgres_database.utils_aiopg import (
-    close_engine,
-    raise_if_migration_not_ready,
-)
-from sqlalchemy import and_
+from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
 
 from .exceptions import NodeNotFound, ProjectNotFoundError
 
-log = logging.getLogger(__name__)
+_logger = logging.getLogger(__name__)
 
 
 async def _get_node_from_db(
-    project_id: str, node_uuid: str, connection: aiopg.sa.SAConnection
-) -> RowProxy:
-    log.debug(
+    project_id: str, node_uuid: str, connection: AsyncConnection
+) -> sa.engine.Row:
+    _logger.debug(
         "Reading from comp_tasks table for node uuid %s, project %s",
         node_uuid,
         project_id,
     )
+    rows_count = await connection.scalar(
+        sa.select(sa.func.count()).select_from(comp_tasks).where(
+            (comp_tasks.c.node_id == node_uuid)
+            & (comp_tasks.c.project_id == project_id),
+        )
+    )
+    if rows_count > 1:
+        _logger.error("the node id %s is not unique", node_uuid)
     result = await connection.execute(
         sa.select(comp_tasks).where(
-            and_(
-                comp_tasks.c.node_id == node_uuid,
-                comp_tasks.c.project_id == project_id,
-            )
+            (comp_tasks.c.node_id == node_uuid)
+            & (comp_tasks.c.project_id == project_id)
         )
     )
-    if result.rowcount > 1:
-        log.error("the node id %s is not unique", node_uuid)
-    node: RowProxy | None = await result.first()
+    node = result.one_or_none()
     if not node:
-        log.error("the node id %s was not found", node_uuid)
+        _logger.error("the node id %s was not found", node_uuid)
         raise NodeNotFound(node_uuid)
     return node
 
 
-@tenacity.retry(**PostgresRetryPolicyUponInitialization().kwargs)
-async def _ensure_postgres_ready(dsn: DataSourceName) -> Engine:
-    engine: aiopg.sa.Engine = await create_pg_engine(dsn, minsize=1, maxsize=4)
-    try:
-        await raise_if_migration_not_ready(engine)
-    except Exception:
-        await close_engine(engine)
-        raise
-    return engine
-
-
 class DBContextManager:
-    def __init__(self, db_engine: aiopg.sa.Engine | None = None):
-        self._db_engine: aiopg.sa.Engine | None = db_engine
+    def __init__(self, db_engine: AsyncEngine | None = None) -> None:
+        self._db_engine: AsyncEngine | None = db_engine
         self._db_engine_created: bool = False
 
     @staticmethod
-    async def _create_db_engine() -> aiopg.sa.Engine:
+    async def _create_db_engine() -> AsyncEngine:
         settings = NodePortsSettings.create_from_envs()
-        dsn = DataSourceName(
-            application_name=f"{__name__}_{socket.gethostname()}_{os.getpid()}",
-            database=settings.POSTGRES_SETTINGS.POSTGRES_DB,
-            user=settings.POSTGRES_SETTINGS.POSTGRES_USER,
-            password=settings.POSTGRES_SETTINGS.POSTGRES_PASSWORD.get_secret_value(),
-            host=settings.POSTGRES_SETTINGS.POSTGRES_HOST,
-            port=settings.POSTGRES_SETTINGS.POSTGRES_PORT,
+        engine = await create_async_engine_and_pg_database_ready(
+            settings.POSTGRES_SETTINGS
         )
-
-        engine: aiopg.sa.Engine = await _ensure_postgres_ready(dsn)
+        assert isinstance(engine, AsyncEngine)  # nosec
         return engine
 
-    async def __aenter__(self):
+    async def __aenter__(self) -> AsyncEngine:
         if not self._db_engine:
             self._db_engine = await self._create_db_engine()
             self._db_engine_created = True
         return self._db_engine
 
-    async def __aexit__(self, exc_type, exc, tb):
+    async def __aexit__(self, exc_type, exc, tb) -> None:
         if self._db_engine and self._db_engine_created:
-            await close_engine(self._db_engine)
-            log.debug(
-                "engine '%s' after shutdown: closed=%s, size=%d",
-                self._db_engine.dsn,
-                self._db_engine.closed,
-                self._db_engine.size,
-            )
+            await self._db_engine.dispose()
 
 
 class DBManager:
-    def __init__(self, db_engine: aiopg.sa.Engine | None = None):
+    def __init__(self, db_engine: AsyncEngine | None = None):
         self._db_engine = db_engine
 
     async def write_ports_configuration(
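The rewritten DBContextManager now yields a SQLAlchemy AsyncEngine and disposes it on exit only when it created the engine itself; an injected engine is left untouched. A sketch of how the DBManager methods in the following hunks consume it for read-only access — the helper name count_tasks is hypothetical, while DBContextManager and comp_tasks are this module's own names:

    import sqlalchemy as sa


    async def count_tasks(db_engine=None) -> int:
        # db_engine may be an injected AsyncEngine or None (then one is built from env settings)
        async with DBContextManager(db_engine) as engine, engine.connect() as connection:
            # read-only query: engine.connect() is enough, no commit required
            count = await connection.scalar(
                sa.select(sa.func.count()).select_from(comp_tasks)
            )
            return count or 0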
@@ -111,20 +81,18 @@ async def write_ports_configuration(
             f"Writing port configuration to database for "
             f"project={project_id} node={node_uuid}: {json_configuration}"
         )
-        log.debug(message)
+        _logger.debug(message)
 
         node_configuration = json.loads(json_configuration)
         async with DBContextManager(
             self._db_engine
-        ) as engine, engine.acquire() as connection:
+        ) as engine, engine.begin() as connection:
             # update the necessary parts
             await connection.execute(
                 comp_tasks.update()
                 .where(
-                    and_(
-                        comp_tasks.c.node_id == node_uuid,
-                        comp_tasks.c.project_id == project_id,
-                    )
+                    (comp_tasks.c.node_id == node_uuid)
+                    & (comp_tasks.c.project_id == project_id),
                 )
                 .values(
                     schema=node_configuration["schema"],
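Note that the write path above pairs DBContextManager with engine.begin() rather than engine.connect(): SQLAlchemy's AsyncEngine.begin() opens a transaction that is committed when the block exits cleanly and rolled back on an exception, so the UPDATE is actually persisted (with plain connect() it would need an explicit connection.commit()). A minimal illustration against a hypothetical items table:

    import sqlalchemy as sa

    # hypothetical table, only to keep the sketch self-contained
    items = sa.table("items", sa.column("id"), sa.column("value"))


    async def update_item(engine, item_id: int, value: str) -> None:
        # begin(): COMMIT on clean exit, ROLLBACK if the block raises
        async with engine.begin() as connection:
            await connection.execute(
                sa.update(items).where(items.c.id == item_id).values(value=value)
            )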
@@ -137,13 +105,13 @@ async def write_ports_configuration(
     async def get_ports_configuration_from_node_uuid(
         self, project_id: str, node_uuid: str
     ) -> str:
-        log.debug(
+        _logger.debug(
             "Getting ports configuration of node %s from comp_tasks table", node_uuid
         )
         async with DBContextManager(
             self._db_engine
-        ) as engine, engine.acquire() as connection:
-            node: RowProxy = await _get_node_from_db(project_id, node_uuid, connection)
+        ) as engine, engine.connect() as connection:
+            node = await _get_node_from_db(project_id, node_uuid, connection)
             node_json_config = json.dumps(
                 {
                     "schema": node.schema,
@@ -152,18 +120,18 @@ async def get_ports_configuration_from_node_uuid(
                     "run_hash": node.run_hash,
                 }
             )
-        log.debug("Found and converted to json")
+        _logger.debug("Found and converted to json")
         return node_json_config
 
     async def get_project_owner_user_id(self, project_id: ProjectID) -> UserID:
         async with DBContextManager(
             self._db_engine
-        ) as engine, engine.acquire() as connection:
-            prj_owner: Any | None = await connection.scalar(
+        ) as engine, engine.connect() as connection:
+            prj_owner = await connection.scalar(
                 sa.select(projects.c.prj_owner).where(
                     projects.c.uuid == f"{project_id}"
                 )
             )
             if prj_owner is None:
                 raise ProjectNotFoundError(project_id)
-            return UserID(prj_owner)
+            return TypeAdapter(UserID).validate_python(prj_owner)
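The last hunk also swaps the bare UserID(prj_owner) cast for Pydantic v2's TypeAdapter, so the scalar read from projects.prj_owner is validated against the UserID type instead of being blindly converted. A small sketch of that behaviour, assuming UserID is a constrained int as declared in models_library; the adapter can be built once at module scope to avoid re-creating it per call:

    from models_library.users import UserID
    from pydantic import TypeAdapter, ValidationError

    _user_id_adapter = TypeAdapter(UserID)

    assert _user_id_adapter.validate_python(42) == 42  # a valid value passes through

    try:
        _user_id_adapter.validate_python(-1)  # rejected if UserID constrains to positive ints
    except ValidationError as err:
        print(err)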
